diff --git a/pithtrain/models/gpt_oss.py b/pithtrain/models/gpt_oss.py index d7bfc6a..81981ee 100644 --- a/pithtrain/models/gpt_oss.py +++ b/pithtrain/models/gpt_oss.py @@ -20,6 +20,7 @@ from pithtrain.dualpipe.modeling import decoder_layer_backward, decoder_layer_forward from pithtrain.dualpipe.utils import run_backward from pithtrain.layers.factory import ModelImplMode, get_linear_cls +from pithtrain.layers.group_linear import GroupLinearFunc from pithtrain.models.interface import ForwardAttnOutput from pithtrain.modules.load_balance import MoELoadBalanceLossInjector, MoELoadBalanceLossTracker from pithtrain.operators.clamped_swiglu import clamped_swiglu @@ -180,12 +181,15 @@ def __init__( self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.down_proj_bias = nn.Parameter(torch.zeros(num_experts, hidden_size)) - def _grouped_mm( - self, x: torch.Tensor, weight: nn.Parameter, offs: torch.Tensor + def _group_linear( + self, + x: torch.Tensor, + weight: nn.Parameter, + offs: torch.Tensor, ) -> torch.Tensor: if x.shape[0] == 0: return x @ weight[0].transpose(-2, -1) - return F.grouped_mm(x, weight.transpose(-2, -1), offs=offs) + return GroupLinearFunc.apply(x, weight, offs) def forward( self, @@ -209,11 +213,11 @@ def forward( right=True, ).clamp_(max=self.num_experts - 1) - gate_up = self._grouped_mm(x, self.gate_up_proj, grouped_mm_offs) + gate_up = self._group_linear(x, self.gate_up_proj, grouped_mm_offs) gate_up = gate_up + self.gate_up_proj_bias[group_ids] activated = clamped_swiglu(gate_up, SWIGLU_ALPHA, self.swiglu_limit) - out = self._grouped_mm(activated, self.down_proj, grouped_mm_offs) + out = self._group_linear(activated, self.down_proj, grouped_mm_offs) out = out + self.down_proj_bias[group_ids] return out diff --git a/tests/test_grouped_linear_correctness.py b/tests/test_grouped_linear_correctness.py index e429089..5275506 100644 --- a/tests/test_grouped_linear_correctness.py +++ b/tests/test_grouped_linear_correctness.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from pithtrain.layers.group_linear import GroupLinear -from pithtrain.operators.token_scatter import scatter_for_grouped_gemm +from pithtrain.operators.token_scatter import _GEMM_ALLOC_ALIGNMENT, scatter_for_grouped_gemm def reference_grouped_linear_forward( @@ -401,14 +401,22 @@ def test_scatter_for_grouped_gemm(): assert offs_new.shape == (num_groups,), f"offs shape {offs_new.shape} != ({num_groups},)" assert offs_new.dtype == torch.int32 - # Verify output shape (exactly sized, no over-allocation) + # Verify output shape: over-allocated up to _GEMM_ALLOC_ALIGNMENT for + # allocator-friendliness, and rows in [offs[-1], out.shape[0]) zeroed + # by the scatter kernel so downstream grouped_mm is safe. 
     m_padded = offs_new[-1].item()
-    assert out_new.shape[0] == m_padded, (
-        f"output rows {out_new.shape[0]} != offs[-1]={m_padded}"
+    assert out_new.shape[0] >= m_padded, f"output rows {out_new.shape[0]} < offs[-1]={m_padded}"
+    assert out_new.shape[0] % _GEMM_ALLOC_ALIGNMENT == 0, (
+        f"output rows {out_new.shape[0]} not aligned to {_GEMM_ALLOC_ALIGNMENT}"
     )
     assert out_new.shape[1] == hidden_size, (
         f"output cols {out_new.shape[1]} != hidden_size={hidden_size}"
     )
+    if out_new.shape[0] > m_padded:
+        tail = out_new[m_padded:]
+        assert torch.all(tail == 0), (
+            f"over-allocated tail [{m_padded}:{out_new.shape[0]}) must be zero"
+        )
 
     # Verify reverse_shuffle_idxs
     assert reverse_new.shape == (m,)
@@ -634,3 +642,113 @@ def test_group_linear_weight_grad_store():
     finally:
         WeightGradStore.enabled = False
         WeightGradStore.clear()
+
+
+def test_gpt_oss_experts_weight_grad_store_matches_direct():
+    """
+    End-to-end sanity check: GptOssExperts (now backed by GroupLinearFunc for
+    the expert GEMMs) produces the same input / weight / bias gradients whether
+    WeightGradStore is enabled or disabled.
+
+    Only the expert-weight wgrad is routed through WeightGradStore; the bias
+    add lives outside GroupLinearFunc, so its grad is computed eagerly via
+    autograd on both paths.
+    """
+    from pithtrain.dualpipe.utils import WeightGradStore
+    from pithtrain.models.gpt_oss import GptOssExperts
+
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+
+    num_experts = 4
+    hidden_size = 128
+    intermediate_size = 256
+    swiglu_limit = 7.0
+
+    # Random routing with one empty group to exercise the zero-token branch.
+    group_sizes = [8, 0, 12, 4]
+    M_total = sum(group_sizes)
+    ks = list(group_sizes)
+    offs = torch.tensor(group_sizes, device=device).cumsum(0).to(torch.int32)
+
+    torch.manual_seed(123)
+    x_raw = torch.randn(M_total, hidden_size, device=device, dtype=dtype)
+
+    experts_ref = GptOssExperts(num_experts, hidden_size, intermediate_size, swiglu_limit)
+    experts_ref = experts_ref.to(device=device, dtype=dtype)
+    with torch.no_grad():
+        experts_ref.gate_up_proj.normal_(std=0.02)
+        experts_ref.down_proj.normal_(std=0.02)
+        experts_ref.gate_up_proj_bias.normal_(std=0.02)
+        experts_ref.down_proj_bias.normal_(std=0.02)
+
+    experts_def = GptOssExperts(num_experts, hidden_size, intermediate_size, swiglu_limit)
+    experts_def = experts_def.to(device=device, dtype=dtype)
+    experts_def.load_state_dict(experts_ref.state_dict())
+
+    # Run 1: WeightGradStore disabled.
+    x_ref = x_raw.detach().clone().requires_grad_(True)
+    out_ref = experts_ref(x_ref, offs, ks=ks)
+    grad = torch.randn_like(out_ref)
+    out_ref.backward(grad)
+
+    # Run 2: WeightGradStore enabled (weight wgrads deferred, bias grads eager).
+    x_def = x_raw.detach().clone().requires_grad_(True)
+    WeightGradStore.enabled = True
+    try:
+        out_def = experts_def(x_def, offs, ks=ks)
+
+        fwd_diff = (out_def - out_ref).abs().max().item()
+        assert torch.allclose(out_def, out_ref, rtol=1e-3, atol=1e-3), (
+            f"forward mismatch between store on/off, max diff {fwd_diff}"
+        )
+
+        out_def.backward(grad)
+
+        # Expert weights go through GroupLinearFunc's deferred wgrad path.
+        for p_def, name in [
+            (experts_def.gate_up_proj, "gate_up_proj"),
+            (experts_def.down_proj, "down_proj"),
+        ]:
+            assert p_def.grad is None, f"[deferred] {name} grad should be deferred"
+
+        # Biases are added outside GroupLinearFunc; autograd populates them eagerly.
+        for p_def, name in [
+            (experts_def.gate_up_proj_bias, "gate_up_proj_bias"),
+            (experts_def.down_proj_bias, "down_proj_bias"),
+        ]:
+            assert p_def.grad is not None, f"[deferred] {name} grad should be eager"
+
+        in_diff = (x_def.grad - x_ref.grad).abs().max().item()
+        assert torch.allclose(x_def.grad, x_ref.grad, rtol=1e-3, atol=1e-3), (
+            f"input grad mismatch, max diff {in_diff}"
+        )
+
+        # Bias grads come from autograd's scatter-add; CUDA index_add_ is
+        # non-deterministic, so allow ~5% slack on bf16 accumulation.
+        for p_ref, p_def, name in [
+            (experts_ref.gate_up_proj_bias, experts_def.gate_up_proj_bias, "gate_up_proj_bias"),
+            (experts_ref.down_proj_bias, experts_def.down_proj_bias, "down_proj_bias"),
+        ]:
+            ref_max = p_ref.grad.abs().max().item()
+            diff = (p_def.grad - p_ref.grad).abs().max().item()
+            assert diff <= max(0.1, ref_max * 0.05), (
+                f"{name} grad mismatch, max diff {diff} (max |ref| {ref_max})"
+            )
+
+        WeightGradStore.flush()
+        WeightGradStore.pop()
+
+        # grouped_mm is deterministic, so weight grads must match tightly.
+        for p_ref, p_def, name in [
+            (experts_ref.gate_up_proj, experts_def.gate_up_proj, "gate_up_proj"),
+            (experts_ref.down_proj, experts_def.down_proj, "down_proj"),
+        ]:
+            assert p_def.grad is not None, f"[deferred] {name} grad missing after pop"
+            diff = (p_def.grad - p_ref.grad).abs().max().item()
+            assert torch.allclose(p_def.grad, p_ref.grad, rtol=1e-3, atol=1e-3), (
+                f"{name} grad mismatch (deferred vs direct), max diff {diff}"
+            )
+    finally:
+        WeightGradStore.enabled = False
+        WeightGradStore.clear()
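
What the switch from F.grouped_mm to GroupLinearFunc buys, in one picture: backward
splits into an eager dgrad (the upstream layers need it immediately) and a wgrad
closure that WeightGradStore queues while enabled and only materializes at
flush()/pop(), which is the window DualPipe overlaps with communication. Below is a
minimal sketch of that pattern, assuming a put()-style queueing method on
WeightGradStore (only enabled, flush(), pop(), and clear() are confirmed by the
tests above); GroupLinearFuncSketch and its naive per-group matmul loop are
illustrative stand-ins, not the pithtrain kernel.

import torch

from pithtrain.dualpipe.utils import WeightGradStore


class GroupLinearFuncSketch(torch.autograd.Function):
    """Hypothetical stand-in for GroupLinearFunc.

    offs holds cumulative end rows per group (as produced by
    scatter_for_grouped_gemm); rows past offs[-1] are zero padding.
    """

    @staticmethod
    def forward(ctx, x, weight, offs):
        ctx.save_for_backward(x, weight, offs)
        out = torch.zeros(x.shape[0], weight.shape[1], device=x.device, dtype=x.dtype)
        start = 0
        for g in range(weight.shape[0]):
            end = int(offs[g])
            out[start:end] = x[start:end] @ weight[g].t()  # (rows, in) @ (in, out)
            start = end
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, offs = ctx.saved_tensors

        # dgrad runs eagerly: the preceding layers' backward depends on it.
        grad_x = torch.zeros_like(x)
        start = 0
        for g in range(weight.shape[0]):
            end = int(offs[g])
            grad_x[start:end] = grad_out[start:end] @ weight[g]
            start = end

        def wgrad():
            # wgrad only needs tensors already captured here, so it can run
            # later, overlapped with pipeline communication.
            grad_w = torch.zeros_like(weight)
            start = 0
            for g in range(weight.shape[0]):
                end = int(offs[g])
                grad_w[g] = grad_out[start:end].t() @ x[start:end]
                start = end
            weight.grad = grad_w if weight.grad is None else weight.grad + grad_w

        if WeightGradStore.enabled:
            WeightGradStore.put(wgrad)  # assumed API: queue for flush()/pop()
            return grad_x, None, None   # weight.grad stays None until pop()
        wgrad()                         # direct path: compute wgrad in-line
        return grad_x, None, None

On the deferred path this is exactly the behavior the new test pins down:
gate_up_proj.grad and down_proj.grad stay None after backward() and appear only
once WeightGradStore.flush() and WeightGradStore.pop() have run.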