Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 119 additions & 4 deletions lm_engine/modeling_utils/sequence_mixer_blocks/mamba2/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

from ....enums import Kernel
from ....generation_cache import ConstantCache, GenerationCache, GenerationState
from ....kernels import is_kernel_allowed
from ....parallel import ProcessGroupManager
from ....parameter import (
mark_parameter_as_initialized,
mark_parameter_as_mup_learning_rate,
Expand Down Expand Up @@ -84,6 +86,56 @@ def _segment_sum(input_tensor: torch.Tensor) -> torch.Tensor:
return tensor_segsum


def _all_gather_context_parallel_with_grad(input_tensor: torch.Tensor) -> torch.Tensor:
    """All-gather ``input_tensor`` over the context-parallel mesh while staying differentiable.

    The local tensor is wrapped as a DTensor sharded on dim 0 of the CP mesh,
    redistributed to a replicated placement and unwrapped again, so the returned
    local tensor is the dim-0 concatenation of every rank's shard.

    Args:
        input_tensor: this rank's shard; it is made contiguous before wrapping.

    Returns:
        The replicated (gathered) tensor as a plain local ``torch.Tensor``.
    """
    local_shard = input_tensor.contiguous()
    sharded = DTensor.from_local(
        local_shard,
        device_mesh=ProcessGroupManager.get_context_parallel_mesh(),
        placements=[Shard(0)],
    )
    replicated = sharded.redistribute(placements=[Replicate()])

    # Each rank consumes the gathered tensor differently, so the gradient arriving
    # here is declared as Partial (a per-rank contribution) rather than replicated.
    return replicated.to_local(grad_placements=[Partial()])


class _SerialPrefixScan(torch.autograd.Function):
    """Serial prefix scan over CP ranks with a hand-written, first-order backward.

    Forward:  s[r] = exp_A[r] * s[r-1] + final[r],  s[-1] = 0, evaluated for
              r = 0 .. cp_rank - 1; the last ``s`` is returned (all zeros when
              ``cp_rank == 0``).
    Backward: chain rule through the linear recurrence without re-entering autograd.

    The intermediate states are stored as plain saved tensors that are detached
    from the inputs, so differentiating *through* ``backward`` would silently
    miss terms. ``backward`` is therefore marked ``once_differentiable``: a
    double backward raises instead of returning an incomplete gradient.
    """

    @staticmethod
    def forward(ctx, all_exp_A: torch.Tensor, all_final: torch.Tensor, cp_rank: int) -> torch.Tensor:
        """Run the recurrence over ranks ``0 .. cp_rank - 1``.

        Args:
            all_exp_A: per-rank transition factors, [cp_world_size, batch, num_heads].
            all_final: per-rank zero-initial final states,
                [cp_world_size, batch, num_heads, head_dim, state_size].
            cp_rank: number of leading ranks to fold into the returned state.

        Returns:
            The state after ``cp_rank`` recurrence steps, shaped like ``all_final[0]``.
        """
        state = torch.zeros_like(all_final[0])
        prev_states = []
        for r in range(cp_rank):
            # The state *entering* step r is what d(out)/d(exp_A[r]) needs.
            prev_states.append(state)
            state = all_exp_A[r][:, :, None, None] * state + all_final[r]

        ctx.save_for_backward(all_exp_A, *prev_states)
        ctx.cp_rank = cp_rank
        ctx.all_final_shape = all_final.shape
        return state

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_s: torch.Tensor) -> "tuple[torch.Tensor | None, torch.Tensor | None, None]":
        """Propagate ``grad_s`` back through the recurrence.

        Returns:
            Gradients for ``all_exp_A`` and ``all_final`` (``None`` for an input
            that does not require grad) and ``None`` for the integer ``cp_rank``.
        """
        all_exp_A, *prev_states = ctx.saved_tensors
        need_exp_A_grad, need_final_grad = ctx.needs_input_grad[:2]

        # Entries for ranks >= cp_rank never took part in forward and stay zero.
        grad_all_exp_A = torch.zeros_like(all_exp_A) if need_exp_A_grad else None
        grad_all_final = (
            torch.zeros(ctx.all_final_shape, dtype=grad_s.dtype, device=grad_s.device)
            if need_final_grad
            else None
        )

        for r in reversed(range(ctx.cp_rank)):
            if grad_all_final is not None:
                grad_all_final[r] = grad_s
            if grad_all_exp_A is not None:
                # exp_A[r] is broadcast over (head_dim, state_size): reduce those dims.
                grad_all_exp_A[r] = (grad_s * prev_states[r]).sum(dim=(-2, -1))
            # Gradient w.r.t. the state entering step r.
            grad_s = grad_s * all_exp_A[r][:, :, None, None]

        return grad_all_exp_A, grad_all_final, None


def _serial_prefix_scan(all_exp_A: torch.Tensor, all_final: torch.Tensor, cp_rank: int) -> torch.Tensor:
    """Functional entry point for ``_SerialPrefixScan``.

    Forwards the gathered transition factors, the gathered zero-initial final
    states and the CP rank to the autograd function and returns its result.
    """
    scanned_state = _SerialPrefixScan.apply(all_exp_A, all_final, cp_rank)
    return scanned_state


class Mamba2(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
Expand Down Expand Up @@ -195,6 +247,31 @@ def __init__(

self.reset_parameters()

def _get_cp_initial_ssm_state(self, ssm_final_zero: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    """Return the SSM state this CP rank has to start its local scan from.

    Every rank contributes its chunk transition ``Phi[r] = exp(A * sum_t dt_t)``
    and the final state ``b[r]`` it reaches from a zero initial state. Both are
    all-gathered and folded locally with a prefix scan:

        s_init[r] = Phi[r-1] * s_init[r-1] + b[r-1]

    Args:
        ssm_final_zero: local final SSM state obtained with a zero initial state;
            reshaped below as (batch, num_heads, head_dim, ssm_state_size) per rank.
        dt: time steps, summed over dim 1 -- assumed to be the effective
            (post-softplus) dt shaped (batch, seq, num_heads); confirm against callers.

    Returns:
        The prefix-scan result for this rank.
    """
    world_size = ProcessGroupManager.get_context_parallel_world_size()
    rank = ProcessGroupManager.get_context_parallel_rank()
    local_batch = ssm_final_zero.shape[0]

    # Diagonal transition of the whole local chunk, computed in fp32.
    neg_A = -torch.exp(self.decay_gate.A_log.float())  # (num_heads,)
    total_dt = dt.float().sum(dim=1)
    chunk_decay = torch.exp(neg_A[None, :] * total_dt)  # (batch, num_heads)

    # Gather order (decay first, then states) is the same on every rank.
    gathered_decay = _all_gather_context_parallel_with_grad(chunk_decay)
    gathered_final = _all_gather_context_parallel_with_grad(ssm_final_zero)

    # The gather stacks ranks along dim 0; split that into (rank, batch).
    decay_shape = (world_size, local_batch, self.num_heads)
    final_shape = (world_size, local_batch, self.num_heads, self.head_dim, self.ssm_state_size)
    gathered_decay = gathered_decay.reshape(*decay_shape)
    gathered_final = gathered_final.reshape(*final_shape)

    return _serial_prefix_scan(gathered_decay, gathered_final, rank)

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -403,15 +480,24 @@ def _torch_forward(

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if use_precomputed_states:
decay_chunk = torch.exp(_segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
decay_chunk = decay_chunk.transpose(1, 3)

if ProcessGroupManager.is_context_parallel_enabled():
# Get final state with zero initial to compute the correct CP initial state
states_zero = torch.cat([torch.zeros_like(states[:, :1]), states], dim=1)
new_states_zero = (decay_chunk[..., None, None] * states_zero[:, :, None, ...]).sum(dim=1)
ssm_state_zero = new_states_zero[:, -1]
previous_states = self._get_cp_initial_ssm_state(ssm_state_zero, dt)
previous_states = previous_states[:, None, ...].to(states.dtype)
elif use_precomputed_states:
previous_states = cache_params.get_cache(self.layer_idx, empty_value=None)[1][:, None, ...].to(
device=states.device
)
else:
previous_states = torch.zeros_like(states[:, :1])

states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(_segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
decay_chunk = decay_chunk.transpose(1, 3)
new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
states, ssm_state = new_states[:, :-1], new_states[:, -1]

Expand Down Expand Up @@ -538,7 +624,9 @@ def _cuda_forward(
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

# 2-4. Fused kernel for conv1d, SSM, and the final projection
if self.training and cache_params is None:
# The fused kernel does not support passing an initial SSM state, so fall through
# to the step-by-step path when context parallelism is active.
if self.training and cache_params is None and not ProcessGroupManager.is_context_parallel_enabled():
out = mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
Expand Down Expand Up @@ -580,6 +668,32 @@ def _cuda_forward(
)

# 3. SSM transformation
if ProcessGroupManager.is_context_parallel_enabled():
# Compute the correct initial SSM state for this CP rank.
# Pass 1: run scan with zero initial to get the local final state.
dt_softplused = F.softplus(dt + self.decay_gate.dt_bias)
if self.time_step_limit != (0.0, float("inf")):
dt_softplused = dt_softplused.clamp(*self.time_step_limit)
scan_output_zero, ssm_state_zero = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
dt,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
return_final_states=True,
dt_bias=self.decay_gate.dt_bias,
dt_softplus=True,
**dt_limit_kwargs,
)
ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Adding scan_output_zero.sum() * 0 to ssm_state_zero forces PyTorch's autograd to execute the backward pass of mamba_chunk_scan_combined for scan_output_zero. Since scan_output_zero is not used elsewhere and its gradient contribution is zero, this redundant backward pass doubles the computation time of the scan kernel during training. You should remove this addition to avoid the performance bottleneck.

Suggested change
ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0
# ssm_state_zero is already tracked by autograd; no need to add scan_output_zero

initial_states = self._get_cp_initial_ssm_state(ssm_state_zero, dt_softplused)
else:
initial_states = None

scan_output, ssm_state = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
dt,
Expand All @@ -593,6 +707,7 @@ def _cuda_forward(
return_final_states=True,
dt_bias=self.decay_gate.dt_bias,
dt_softplus=True,
initial_states=initial_states,
**dt_limit_kwargs,
)

Expand Down
102 changes: 102 additions & 0 deletions tests/context_parallel/mamba2_cp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# **************************************************
# Copyright (c) 2026, Mayank Mishra
# **************************************************

import argparse
import os

import torch
import torch.distributed
from torch.testing import assert_close

from lm_engine.enums import Kernel
from lm_engine.kernels import enable_kernels
from lm_engine.modeling_utils.sequence_mixer_blocks.mamba2 import Mamba2, Mamba2Args
from lm_engine.parallel import ProcessGroupManager, prepare_context_parallel_input


# Context-parallel (CP) check for Mamba2, run once per process.
# Each process runs the layer on its local input, the per-rank outputs and input
# gradients are all-gathered and concatenated on dim 1, and rank 0 compares them
# against a run on the full input with the CP world size overridden to 1.

# --use-mamba2-ssm selects Kernel.mamba2_ssm; without it no kernels are enabled.
parser = argparse.ArgumentParser()
parser.add_argument("--use-mamba2-ssm", action="store_true")
args = parser.parse_args()

# Every launched process joins the CP group. WORLD_SIZE must be present in the
# environment: os.getenv returns None when it is unset and int(None) raises TypeError.
cp_world_size = int(os.getenv("WORLD_SIZE"))
ProcessGroupManager(context_parallel_world_size=cp_world_size)

rank = ProcessGroupManager.get_context_parallel_rank()
device = torch.cuda.current_device()

# Problem sizes for the layer under test.
_HIDDEN_SIZE = 64
_INTERMEDIATE_SIZE = 128
_NUM_HEADS = 8
_STATE_SIZE = 16
_N_GROUPS = 1
_CHUNK_LEN = 32 # sequence length per CP rank; must be a multiple of chunk_size
_BATCH = 2

config = Mamba2Args(
state_size=_STATE_SIZE,
intermediate_size=_INTERMEDIATE_SIZE,
num_heads=_NUM_HEADS,
conv_kernel_size=4,
activation_function="silu",
num_groups=_N_GROUPS,
chunk_size=16,
normalization_function="rmsnorm",
)

# The RNG is seeded identically in every process right before the layer is built.
torch.manual_seed(42)
mamba2 = Mamba2(
hidden_size=_HIDDEN_SIZE,
config=config,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
m_width=1.0,
init_method="normal",
num_layers=1,
layer_idx=0,
use_depth_scaled_init=False,
).to(device)
mamba2.eval()

# Re-seed so every process draws the same full-length input, then take the local part.
# NOTE(review): presumably prepare_context_parallel_input returns this rank's slice
# along the sequence dim (the gathered results below are concatenated on dim 1) -- confirm.
torch.manual_seed(0)
x_full = torch.randn(_BATCH, cp_world_size * _CHUNK_LEN, _HIDDEN_SIZE, device=device)
x_local = prepare_context_parallel_input((x_full,))[0]

kernels = [Kernel.mamba2_ssm] if args.use_mamba2_ssm else []
cp_group = ProcessGroupManager.get_context_parallel_group()

# ---- forward ----
with enable_kernels(kernels):
out_local = mamba2(x_local.detach())

# Collect every rank's local output and stitch them together on dim 1.
parts = [torch.zeros_like(out_local) for _ in range(cp_world_size)]
torch.distributed.all_gather(parts, out_local.detach().contiguous(), group=cp_group)
out_cp_full = torch.cat(parts, dim=1)

# Only rank 0 computes the reference: same layer, full input, CP world size forced to 1.
if rank == 0:
with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1):
out_ref = mamba2(x_full.detach())

assert_close(out_cp_full, out_ref)

# ---- backward (correctness: compare input grads against non-CP reference) ----
# A fresh leaf tensor is used so that .grad holds only this pass's input gradient.
x_local_bwd = x_local.detach().requires_grad_(True)
with enable_kernels(kernels):
out_local_bwd = mamba2(x_local_bwd)
out_local_bwd.sum().backward()

assert x_local_bwd.grad is not None

# Gather the per-rank input gradients in the same layout as the outputs above.
grad_parts = [torch.zeros_like(x_local_bwd.grad) for _ in range(cp_world_size)]
torch.distributed.all_gather(grad_parts, x_local_bwd.grad.contiguous(), group=cp_group)
grad_cp_full = torch.cat(grad_parts, dim=1)

# Reference input gradient from the full-input run, again on rank 0 only.
if rank == 0:
x_full_bwd = x_full.detach().requires_grad_(True)
with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1):
out_ref_bwd = mamba2(x_full_bwd)
out_ref_bwd.sum().backward()

assert_close(grad_cp_full, x_full_bwd.grad)

ProcessGroupManager.destroy_process_groups()
38 changes: 38 additions & 0 deletions tests/context_parallel/mamba2_cp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# **************************************************
# Copyright (c) 2026, Mayank Mishra
# **************************************************

import subprocess

import pytest
import torch

from lm_engine.utils import is_mamba_2_ssm_available

from ..utils import skip_test_if_device_unavailable, slow_test


@pytest.mark.parametrize("use_mamba2_ssm", [True])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test currently only parametrizes use_mamba2_ssm with [True], which means only the Triton kernel path is tested. Since context parallel support was also added to the PyTorch-native path (_torch_forward), you should include False in the parameter list to ensure both paths are covered by the test suite.

Suggested change
@pytest.mark.parametrize("use_mamba2_ssm", [True])
@pytest.mark.parametrize("use_mamba2_ssm", [True, False])

@slow_test
def test_mamba2_cp(use_mamba2_ssm: bool) -> None:
    """Run the Mamba2 context-parallel script under ``torchrun`` on every visible GPU.

    The test is skipped when CUDA is unavailable, when the mamba_ssm kernels are
    requested but not installed, or when fewer than two GPUs are present. A
    non-zero exit status of the launched script fails the test (``check=True``).
    """
    skip_test_if_device_unavailable(torch.device("cuda"))

    if use_mamba2_ssm and not is_mamba_2_ssm_available():
        pytest.skip("mamba_ssm unavailable")

    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        pytest.skip("context parallel requires at least 2 GPUs")

    # The kernel flag is forwarded to the script only when it was requested.
    kernel_flags = ["--use-mamba2-ssm"] if use_mamba2_ssm else []
    launch_command = [
        "torchrun",
        "--nproc_per_node",
        str(num_gpus),
        "-m",
        "tests.context_parallel.mamba2_cp",
        *kernel_flags,
    ]

    subprocess.run(launch_command, check=True)
Loading