diff --git a/lm_engine/modeling_utils/sequence_mixer_blocks/mamba2/module.py b/lm_engine/modeling_utils/sequence_mixer_blocks/mamba2/module.py index f4ca1bd7d..3271be3cb 100644 --- a/lm_engine/modeling_utils/sequence_mixer_blocks/mamba2/module.py +++ b/lm_engine/modeling_utils/sequence_mixer_blocks/mamba2/module.py @@ -7,10 +7,12 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard from ....enums import Kernel from ....generation_cache import ConstantCache, GenerationCache, GenerationState from ....kernels import is_kernel_allowed +from ....parallel import ProcessGroupManager from ....parameter import ( mark_parameter_as_initialized, mark_parameter_as_mup_learning_rate, @@ -84,6 +86,56 @@ def _segment_sum(input_tensor: torch.Tensor) -> torch.Tensor: return tensor_segsum +def _all_gather_context_parallel_with_grad(input_tensor: torch.Tensor) -> torch.Tensor: + cp_mesh = ProcessGroupManager.get_context_parallel_mesh() + dtensor = DTensor.from_local(input_tensor.contiguous(), device_mesh=cp_mesh, placements=[Shard(0)]) + dtensor = dtensor.redistribute(placements=[Replicate()]) + + return dtensor.to_local(grad_placements=[Partial()]) + + +class _SerialPrefixScan(torch.autograd.Function): + """Serial prefix scan over CP ranks with manual backward. + + Forward: s[r] = exp_A[r] * s[r-1] + final[r], s[-1] = 0 + Backward: chain-rule through the linear recurrence without re-entering autograd. + """ + + @staticmethod + def forward(ctx, all_exp_A: torch.Tensor, all_final: torch.Tensor, cp_rank: int) -> torch.Tensor: + # all_exp_A : [cp_world_size, batch, num_heads] + # all_final : [cp_world_size, batch, num_heads, head_dim, state_size] + prev_states = [] + s = torch.zeros_like(all_final[0]) + for r in range(cp_rank): + prev_states.append(s) + s = all_exp_A[r][:, :, None, None] * s + all_final[r] + ctx.save_for_backward(all_exp_A, *prev_states) + ctx.cp_rank = cp_rank + ctx.all_final_shape = all_final.shape + return s + + @staticmethod + def backward(ctx, grad_s: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: + all_exp_A = ctx.saved_tensors[0] + prev_states = ctx.saved_tensors[1:] + cp_rank = ctx.cp_rank + + grad_all_exp_A = torch.zeros_like(all_exp_A) + grad_all_final = torch.zeros(ctx.all_final_shape, dtype=grad_s.dtype, device=grad_s.device) + + for r in range(cp_rank - 1, -1, -1): + grad_all_final[r] = grad_s + grad_all_exp_A[r] = (grad_s * prev_states[r]).sum(dim=(-2, -1)) + grad_s = grad_s * all_exp_A[r][:, :, None, None] + + return grad_all_exp_A, grad_all_final, None + + +def _serial_prefix_scan(all_exp_A: torch.Tensor, all_final: torch.Tensor, cp_rank: int) -> torch.Tensor: + return _SerialPrefixScan.apply(all_exp_A, all_final, cp_rank) + + class Mamba2(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. @@ -195,6 +247,31 @@ def __init__( self.reset_parameters() + def _get_cp_initial_ssm_state(self, ssm_final_zero: torch.Tensor, dt: torch.Tensor) -> torch.Tensor: + """Compute the correct initial SSM state for this CP rank. + + Uses all-gather + local prefix scan so every CP world size works exactly: + s_init[r] = Phi[r-1] * s_init[r-1] + b[r-1] + where Phi[r] = exp(A * Σ_t dt_t) is the chunk transition and b[r] is the + zero-initial final state. + """ + cp_rank = ProcessGroupManager.get_context_parallel_rank() + cp_world_size = ProcessGroupManager.get_context_parallel_world_size() + batch_size = ssm_final_zero.shape[0] + + # Diagonal transition factor: exp(A[h] * Σ_t dt_eff[b,t,h]) + A_neg = -torch.exp(self.decay_gate.A_log.float()) # (num_heads,) + exp_A_chunk = torch.exp(A_neg[None, :] * dt.float().sum(dim=1)) # (batch, num_heads) + + # All-gather both tensors from every rank (gathered along batch dim 0). + all_exp_A = _all_gather_context_parallel_with_grad(exp_A_chunk) + all_final = _all_gather_context_parallel_with_grad(ssm_final_zero) + + all_exp_A = all_exp_A.reshape(cp_world_size, batch_size, self.num_heads) + all_final = all_final.reshape(cp_world_size, batch_size, self.num_heads, self.head_dim, self.ssm_state_size) + + return _serial_prefix_scan(all_exp_A, all_final, cp_rank) + def forward( self, hidden_states: torch.Tensor, @@ -403,15 +480,24 @@ def _torch_forward( # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - if use_precomputed_states: + decay_chunk = torch.exp(_segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + + if ProcessGroupManager.is_context_parallel_enabled(): + # Get final state with zero initial to compute the correct CP initial state + states_zero = torch.cat([torch.zeros_like(states[:, :1]), states], dim=1) + new_states_zero = (decay_chunk[..., None, None] * states_zero[:, :, None, ...]).sum(dim=1) + ssm_state_zero = new_states_zero[:, -1] + previous_states = self._get_cp_initial_ssm_state(ssm_state_zero, dt) + previous_states = previous_states[:, None, ...].to(states.dtype) + elif use_precomputed_states: previous_states = cache_params.get_cache(self.layer_idx, empty_value=None)[1][:, None, ...].to( device=states.device ) else: previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) - decay_chunk = torch.exp(_segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) - decay_chunk = decay_chunk.transpose(1, 3) new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) states, ssm_state = new_states[:, :-1], new_states[:, -1] @@ -538,7 +624,9 @@ def _cuda_forward( dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} # 2-4. Fused kernel for conv1d, SSM, and the final projection - if self.training and cache_params is None: + # The fused kernel does not support passing an initial SSM state, so fall through + # to the step-by-step path when context parallelism is active. + if self.training and cache_params is None and not ProcessGroupManager.is_context_parallel_enabled(): out = mamba_split_conv1d_scan_combined( projected_states, self.conv1d.weight.squeeze(1), @@ -580,6 +668,32 @@ def _cuda_forward( ) # 3. SSM transformation + if ProcessGroupManager.is_context_parallel_enabled(): + # Compute the correct initial SSM state for this CP rank. + # Pass 1: run scan with zero initial to get the local final state. + dt_softplused = F.softplus(dt + self.decay_gate.dt_bias) + if self.time_step_limit != (0.0, float("inf")): + dt_softplused = dt_softplused.clamp(*self.time_step_limit) + scan_output_zero, ssm_state_zero = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.decay_gate.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0 + initial_states = self._get_cp_initial_ssm_state(ssm_state_zero, dt_softplused) + else: + initial_states = None + scan_output, ssm_state = mamba_chunk_scan_combined( hidden_states.view(batch_size, seq_len, -1, self.head_dim), dt, @@ -593,6 +707,7 @@ def _cuda_forward( return_final_states=True, dt_bias=self.decay_gate.dt_bias, dt_softplus=True, + initial_states=initial_states, **dt_limit_kwargs, ) diff --git a/tests/context_parallel/mamba2_cp.py b/tests/context_parallel/mamba2_cp.py new file mode 100644 index 000000000..1c619d589 --- /dev/null +++ b/tests/context_parallel/mamba2_cp.py @@ -0,0 +1,102 @@ +# ************************************************** +# Copyright (c) 2026, Mayank Mishra +# ************************************************** + +import argparse +import os + +import torch +import torch.distributed +from torch.testing import assert_close + +from lm_engine.enums import Kernel +from lm_engine.kernels import enable_kernels +from lm_engine.modeling_utils.sequence_mixer_blocks.mamba2 import Mamba2, Mamba2Args +from lm_engine.parallel import ProcessGroupManager, prepare_context_parallel_input + + +parser = argparse.ArgumentParser() +parser.add_argument("--use-mamba2-ssm", action="store_true") +args = parser.parse_args() + +cp_world_size = int(os.getenv("WORLD_SIZE")) +ProcessGroupManager(context_parallel_world_size=cp_world_size) + +rank = ProcessGroupManager.get_context_parallel_rank() +device = torch.cuda.current_device() + +_HIDDEN_SIZE = 64 +_INTERMEDIATE_SIZE = 128 +_NUM_HEADS = 8 +_STATE_SIZE = 16 +_N_GROUPS = 1 +_CHUNK_LEN = 32 # sequence length per CP rank; must be a multiple of chunk_size +_BATCH = 2 + +config = Mamba2Args( + state_size=_STATE_SIZE, + intermediate_size=_INTERMEDIATE_SIZE, + num_heads=_NUM_HEADS, + conv_kernel_size=4, + activation_function="silu", + num_groups=_N_GROUPS, + chunk_size=16, + normalization_function="rmsnorm", +) + +torch.manual_seed(42) +mamba2 = Mamba2( + hidden_size=_HIDDEN_SIZE, + config=config, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + m_width=1.0, + init_method="normal", + num_layers=1, + layer_idx=0, + use_depth_scaled_init=False, +).to(device) +mamba2.eval() + +torch.manual_seed(0) +x_full = torch.randn(_BATCH, cp_world_size * _CHUNK_LEN, _HIDDEN_SIZE, device=device) +x_local = prepare_context_parallel_input((x_full,))[0] + +kernels = [Kernel.mamba2_ssm] if args.use_mamba2_ssm else [] +cp_group = ProcessGroupManager.get_context_parallel_group() + +# ---- forward ---- +with enable_kernels(kernels): + out_local = mamba2(x_local.detach()) + +parts = [torch.zeros_like(out_local) for _ in range(cp_world_size)] +torch.distributed.all_gather(parts, out_local.detach().contiguous(), group=cp_group) +out_cp_full = torch.cat(parts, dim=1) + +if rank == 0: + with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1): + out_ref = mamba2(x_full.detach()) + + assert_close(out_cp_full, out_ref) + +# ---- backward (correctness: compare input grads against non-CP reference) ---- +x_local_bwd = x_local.detach().requires_grad_(True) +with enable_kernels(kernels): + out_local_bwd = mamba2(x_local_bwd) + out_local_bwd.sum().backward() + +assert x_local_bwd.grad is not None + +grad_parts = [torch.zeros_like(x_local_bwd.grad) for _ in range(cp_world_size)] +torch.distributed.all_gather(grad_parts, x_local_bwd.grad.contiguous(), group=cp_group) +grad_cp_full = torch.cat(grad_parts, dim=1) + +if rank == 0: + x_full_bwd = x_full.detach().requires_grad_(True) + with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1): + out_ref_bwd = mamba2(x_full_bwd) + out_ref_bwd.sum().backward() + + assert_close(grad_cp_full, x_full_bwd.grad) + +ProcessGroupManager.destroy_process_groups() diff --git a/tests/context_parallel/mamba2_cp_test.py b/tests/context_parallel/mamba2_cp_test.py new file mode 100644 index 000000000..56cdc19f3 --- /dev/null +++ b/tests/context_parallel/mamba2_cp_test.py @@ -0,0 +1,38 @@ +# ************************************************** +# Copyright (c) 2026, Mayank Mishra +# ************************************************** + +import subprocess + +import pytest +import torch + +from lm_engine.utils import is_mamba_2_ssm_available + +from ..utils import skip_test_if_device_unavailable, slow_test + + +@pytest.mark.parametrize("use_mamba2_ssm", [True]) +@slow_test +def test_mamba2_cp(use_mamba2_ssm: bool) -> None: + skip_test_if_device_unavailable(torch.device("cuda")) + + if use_mamba2_ssm and not is_mamba_2_ssm_available(): + pytest.skip("mamba_ssm unavailable") + + gpus_per_node = torch.cuda.device_count() + if gpus_per_node < 2: + pytest.skip("context parallel requires at least 2 GPUs") + + command = [ + "torchrun", + "--nproc_per_node", + str(gpus_per_node), + "-m", + "tests.context_parallel.mamba2_cp", + ] + + if use_mamba2_ssm: + command.append("--use-mamba2-ssm") + + subprocess.run(command, check=True)