pithtrain/models/gpt_oss.py (14 changes: 9 additions & 5 deletions)
@@ -20,6 +20,7 @@
 from pithtrain.dualpipe.modeling import decoder_layer_backward, decoder_layer_forward
 from pithtrain.dualpipe.utils import run_backward
 from pithtrain.layers.factory import ModelImplMode, get_linear_cls
+from pithtrain.layers.group_linear import GroupLinearFunc
 from pithtrain.models.interface import ForwardAttnOutput
 from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker
 from pithtrain.operators.clamped_swiglu import clamped_swiglu
@@ -180,12 +181,15 @@ def __init__(
         self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         self.down_proj_bias = nn.Parameter(torch.zeros(num_experts, hidden_size))

-    def _grouped_mm(
-        self, x: torch.Tensor, weight: nn.Parameter, offs: torch.Tensor
+    def _group_linear(
+        self,
+        x: torch.Tensor,
+        weight: nn.Parameter,
+        offs: torch.Tensor,
     ) -> torch.Tensor:
         if x.shape[0] == 0:
             return x @ weight[0].transpose(-2, -1)
-        return F.grouped_mm(x, weight.transpose(-2, -1), offs=offs)
+        return GroupLinearFunc.apply(x, weight, offs)

     def forward(
         self,
@@ -209,11 +213,11 @@ def forward(
             right=True,
         ).clamp_(max=self.num_experts - 1)

-        gate_up = self._grouped_mm(x, self.gate_up_proj, grouped_mm_offs)
+        gate_up = self._group_linear(x, self.gate_up_proj, grouped_mm_offs)
         gate_up = gate_up + self.gate_up_proj_bias[group_ids]
         activated = clamped_swiglu(gate_up, SWIGLU_ALPHA, self.swiglu_limit)

-        out = self._grouped_mm(activated, self.down_proj, grouped_mm_offs)
+        out = self._group_linear(activated, self.down_proj, grouped_mm_offs)
         out = out + self.down_proj_bias[group_ids]
         return out
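The model change above swaps the direct `F.grouped_mm` call for `GroupLinearFunc.apply`, i.e. the expert GEMMs now go through a custom autograd Function. That is what makes it possible to split the backward pass: the input gradient is produced immediately, while the weight gradient (wgrad) can be handed to `WeightGradStore` and executed later, as the new test further down verifies. The actual `GroupLinearFunc` lives in `pithtrain/layers/group_linear.py` and is not part of this diff, so the sketch below is only an assumed shape of such a Function; the helper `grouped_mm_ref`, the class name `GroupLinearFuncSketch`, and the inline wgrad evaluation are all illustrative, not the real implementation.

```python
# Illustrative sketch only; the real GroupLinearFunc is not shown in this diff.
import torch


def grouped_mm_ref(x, mats, offs):
    """Per-group matmul: rows offs[g-1]:offs[g] of x are multiplied by mats[g]."""
    out = x.new_zeros(x.shape[0], mats.shape[-1])
    start = 0
    for g in range(mats.shape[0]):
        end = int(offs[g])
        out[start:end] = x[start:end] @ mats[g]
        start = end
    return out


class GroupLinearFuncSketch(torch.autograd.Function):
    """Grouped linear whose weight gradient can be deferred (illustrative)."""

    @staticmethod
    def forward(ctx, x, weight, offs):
        ctx.save_for_backward(x, weight, offs)
        # weight is (num_groups, out_features, in_features); y = x @ W_g^T per group.
        return grouped_mm_ref(x, weight.transpose(-2, -1), offs)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, offs = ctx.saved_tensors

        # Input gradient is always produced immediately: dX_g = dY_g @ W_g.
        grad_x = grouped_mm_ref(grad_out, weight, offs)

        def wgrad():
            grad_w = torch.zeros_like(weight)
            start = 0
            for g in range(weight.shape[0]):
                end = int(offs[g])
                # dW_g = dY_g^T @ X_g over the rows routed to group g.
                grad_w[g] = grad_out[start:end].T @ x[start:end]
                start = end
            return grad_w

        # The real implementation would enqueue `wgrad` into WeightGradStore
        # when deferral is enabled; this sketch just evaluates it inline.
        return grad_x, wgrad(), None
```

In the diff above, `GptOssExperts._group_linear` feeds both the gate/up and down projections through this path, while the zero-token fast path still falls back to a dense matmul against the first expert's weight so shapes and autograd connectivity stay consistent.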
tests/test_grouped_linear_correctness.py (126 changes: 122 additions & 4 deletions)
@@ -11,7 +11,7 @@
 import torch.nn.functional as F

 from pithtrain.layers.group_linear import GroupLinear
-from pithtrain.operators.token_scatter import scatter_for_grouped_gemm
+from pithtrain.operators.token_scatter import _GEMM_ALLOC_ALIGNMENT, scatter_for_grouped_gemm


 def reference_grouped_linear_forward(
@@ -401,14 +401,22 @@ def test_scatter_for_grouped_gemm():
     assert offs_new.shape == (num_groups,), f"offs shape {offs_new.shape} != ({num_groups},)"
     assert offs_new.dtype == torch.int32

-    # Verify output shape (exactly sized, no over-allocation)
+    # Verify output shape: over-allocated up to _GEMM_ALLOC_ALIGNMENT for
+    # allocator-friendliness, and rows in [offs[-1], out.shape[0]) zeroed
+    # by the scatter kernel so downstream grouped_mm is safe.
     m_padded = offs_new[-1].item()
-    assert out_new.shape[0] == m_padded, (
-        f"output rows {out_new.shape[0]} != offs[-1]={m_padded}"
+    assert out_new.shape[0] >= m_padded, f"output rows {out_new.shape[0]} < offs[-1]={m_padded}"
+    assert out_new.shape[0] % _GEMM_ALLOC_ALIGNMENT == 0, (
+        f"output rows {out_new.shape[0]} not aligned to {_GEMM_ALLOC_ALIGNMENT}"
     )
     assert out_new.shape[1] == hidden_size, (
         f"output cols {out_new.shape[1]} != hidden_size={hidden_size}"
     )
+    if out_new.shape[0] > m_padded:
+        tail = out_new[m_padded:]
+        assert torch.all(tail == 0), (
+            f"over-allocated tail [{m_padded}:{out_new.shape[0]}) must be zero"
+        )

     # Verify reverse_shuffle_idxs
     assert reverse_new.shape == (m,)
@@ -634,3 +642,113 @@ def test_group_linear_weight_grad_store():
     finally:
         WeightGradStore.enabled = False
         WeightGradStore.clear()
+
+
+def test_gpt_oss_experts_weight_grad_store_matches_direct():
+    """
+    End-to-end sanity check: GptOssExperts (now backed by GroupLinearFunc for
+    the expert GEMMs) produces the same input / weight / bias gradients whether
+    WeightGradStore is enabled or disabled.
+
+    Only the expert-weight wgrad is routed through WeightGradStore; the bias
+    add lives outside GroupLinearFunc so its grad is computed eagerly via
+    autograd on both paths.
+    """
+    from pithtrain.dualpipe.utils import WeightGradStore
+    from pithtrain.models.gpt_oss import GptOssExperts
+
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+
+    num_experts = 4
+    hidden_size = 128
+    intermediate_size = 256
+    swiglu_limit = 7.0
+
+    # Random routing with one empty group to exercise the zero-token branch.
+    group_sizes = [8, 0, 12, 4]
+    M_total = sum(group_sizes)
+    ks = list(group_sizes)
+    offs = torch.tensor(group_sizes, device=device).cumsum(0).to(torch.int32)
+
+    torch.manual_seed(123)
+    x_raw = torch.randn(M_total, hidden_size, device=device, dtype=dtype)
+
+    experts_ref = GptOssExperts(num_experts, hidden_size, intermediate_size, swiglu_limit)
+    experts_ref = experts_ref.to(device=device, dtype=dtype)
+    with torch.no_grad():
+        experts_ref.gate_up_proj.normal_(std=0.02)
+        experts_ref.down_proj.normal_(std=0.02)
+        experts_ref.gate_up_proj_bias.normal_(std=0.02)
+        experts_ref.down_proj_bias.normal_(std=0.02)
+
+    experts_def = GptOssExperts(num_experts, hidden_size, intermediate_size, swiglu_limit)
+    experts_def = experts_def.to(device=device, dtype=dtype)
+    experts_def.load_state_dict(experts_ref.state_dict())
+
+    # Run 1: WeightGradStore disabled.
+    x_ref = x_raw.detach().clone().requires_grad_(True)
+    out_ref = experts_ref(x_ref, offs, ks=ks)
+    grad = torch.randn_like(out_ref)
+    out_ref.backward(grad)
+
+    # Run 2: WeightGradStore enabled (weight wgrads deferred, bias bgrads eager).
+    x_def = x_raw.detach().clone().requires_grad_(True)
+    WeightGradStore.enabled = True
+    try:
+        out_def = experts_def(x_def, offs, ks=ks)
+
+        fwd_diff = (out_def - out_ref).abs().max().item()
+        assert torch.allclose(out_def, out_ref, rtol=1e-3, atol=1e-3), (
+            f"forward mismatch between store on/off, max diff {fwd_diff}"
+        )
+
+        out_def.backward(grad)
+
+        # Expert weights go through GroupLinearFunc's deferred wgrad path.
+        for p_def, name in [
+            (experts_def.gate_up_proj, "gate_up_proj"),
+            (experts_def.down_proj, "down_proj"),
+        ]:
+            assert p_def.grad is None, f"[deferred] {name} grad should be deferred"
+
+        # Biases are added outside GroupLinearFunc; autograd populates them eagerly.
+        for p_def, name in [
+            (experts_def.gate_up_proj_bias, "gate_up_proj_bias"),
+            (experts_def.down_proj_bias, "down_proj_bias"),
+        ]:
+            assert p_def.grad is not None, f"[deferred] {name} grad should be eager"
+
+        in_diff = (x_def.grad - x_ref.grad).abs().max().item()
+        assert torch.allclose(x_def.grad, x_ref.grad, rtol=1e-3, atol=1e-3), (
+            f"input grad mismatch, max diff {in_diff}"
+        )
+
+        # Bias grads come from autograd's scatter-add (CUDA index_add_ is
+        # non-deterministic, so allow ~5% slack on bf16 accumulation).
+        for p_ref, p_def, name in [
+            (experts_ref.gate_up_proj_bias, experts_def.gate_up_proj_bias, "gate_up_proj_bias"),
+            (experts_ref.down_proj_bias, experts_def.down_proj_bias, "down_proj_bias"),
+        ]:
+            ref_max = p_ref.grad.abs().max().item()
+            diff = (p_def.grad - p_ref.grad).abs().max().item()
+            assert diff <= max(0.1, ref_max * 0.05), (
+                f"{name} grad mismatch, max diff {diff} (max |ref| {ref_max})"
+            )
+
+        WeightGradStore.flush()
+        WeightGradStore.pop()
+
+        # grouped_mm is deterministic, so weight grads must match tightly.
+        for p_ref, p_def, name in [
+            (experts_ref.gate_up_proj, experts_def.gate_up_proj, "gate_up_proj"),
+            (experts_ref.down_proj, experts_def.down_proj, "down_proj"),
+        ]:
+            assert p_def.grad is not None, f"[deferred] {name} grad missing after pop"
+            diff = (p_def.grad - p_ref.grad).abs().max().item()
+            assert torch.allclose(p_def.grad, p_ref.grad, rtol=1e-3, atol=1e-3), (
+                f"{name} grad mismatch (deferred vs direct), max diff {diff}"
+            )
+    finally:
+        WeightGradStore.enabled = False
+        WeightGradStore.clear()
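For reference, the reworked shape checks in `test_scatter_for_grouped_gemm` earlier in this file's diff describe the new allocation contract: the scatter output may hold more rows than `offs[-1]`, its row count is a multiple of `_GEMM_ALLOC_ALIGNMENT`, and the over-allocated tail is zeroed so a downstream grouped GEMM can read it safely. A minimal sketch of that padding rule is below; the alignment value of 128 is a stand-in, since the real constant is defined in `pithtrain/operators/token_scatter.py` and its value does not appear in this diff.

```python
import torch

# Placeholder value; the real _GEMM_ALLOC_ALIGNMENT constant is defined in
# pithtrain/operators/token_scatter.py and is not shown in this diff.
_GEMM_ALLOC_ALIGNMENT = 128


def padded_rows(m_padded: int) -> int:
    # Round the scattered row count up to the next multiple of the alignment,
    # which is one way to satisfy the test's two shape assertions.
    return -(-m_padded // _GEMM_ALLOC_ALIGNMENT) * _GEMM_ALLOC_ALIGNMENT


# With offs[-1] == 200 scattered rows and a 128-row alignment, the buffer
# holds 256 rows, and rows [200:256] must read back as zeros.
m_padded, hidden_size = 200, 64
out = torch.zeros(padded_rows(m_padded), hidden_size)
assert out.shape[0] >= m_padded
assert out.shape[0] % _GEMM_ALLOC_ALIGNMENT == 0
assert torch.all(out[m_padded:] == 0)
```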