-
Notifications
You must be signed in to change notification settings - Fork 30
[CONTEXT PARALLEL] add CP for mamba2 #482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # ************************************************** | ||
| # Copyright (c) 2026, Mayank Mishra | ||
| # ************************************************** | ||
|
|
||
| import argparse | ||
| import os | ||
|
|
||
| import torch | ||
| import torch.distributed | ||
| from torch.testing import assert_close | ||
|
|
||
| from lm_engine.enums import Kernel | ||
| from lm_engine.kernels import enable_kernels | ||
| from lm_engine.modeling_utils.sequence_mixer_blocks.mamba2 import Mamba2, Mamba2Args | ||
| from lm_engine.parallel import ProcessGroupManager, prepare_context_parallel_input | ||
|
|
||
|
|
||
| # Single CLI flag: when passed, the mamba2_ssm kernel is enabled for every model call below. | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--use-mamba2-ssm", action="store_true") | ||
| args = parser.parse_args() | ||
|
|
||
| # WORLD_SIZE is read from the environment; int(None) raises TypeError when it is unset. | ||
| # NOTE(review): presumably exported by the torchrun launcher — confirm. | ||
| cp_world_size = int(os.getenv("WORLD_SIZE")) | ||
| # The context-parallel world size is set to the full WORLD_SIZE. | ||
| ProcessGroupManager(context_parallel_world_size=cp_world_size) | ||
|
|
||
| # This process's rank within the context-parallel group; rank 0 runs the reference checks below. | ||
| rank = ProcessGroupManager.get_context_parallel_rank() | ||
| device = torch.cuda.current_device() | ||
|
|
||
| # Problem dimensions for the Mamba2 block under test. | ||
| _HIDDEN_SIZE = 64 | ||
| _INTERMEDIATE_SIZE = 128 | ||
| _NUM_HEADS = 8 | ||
| _STATE_SIZE = 16 | ||
| _N_GROUPS = 1 | ||
| _CHUNK_LEN = 32 # sequence length per CP rank; must be a multiple of chunk_size | ||
| _BATCH = 2 | ||
|
|
||
| # chunk_size=16 divides _CHUNK_LEN (32), satisfying the constraint noted on _CHUNK_LEN. | ||
| config = Mamba2Args( | ||
| state_size=_STATE_SIZE, | ||
| intermediate_size=_INTERMEDIATE_SIZE, | ||
| num_heads=_NUM_HEADS, | ||
| conv_kernel_size=4, | ||
| activation_function="silu", | ||
| num_groups=_N_GROUPS, | ||
| chunk_size=16, | ||
| normalization_function="rmsnorm", | ||
| ) | ||
|
|
||
| # Fixed seed right before construction. | ||
| # NOTE(review): presumably so every rank initializes identical parameters — confirm. | ||
| torch.manual_seed(42) | ||
| mamba2 = Mamba2( | ||
| hidden_size=_HIDDEN_SIZE, | ||
| config=config, | ||
| layer_norm_epsilon=1e-5, | ||
| initializer_range=0.02, | ||
| m_width=1.0, | ||
| init_method="normal", | ||
| num_layers=1, | ||
| layer_idx=0, | ||
| use_depth_scaled_init=False, | ||
| ).to(device) | ||
| # Switch the module to eval mode before running the checks. | ||
| mamba2.eval() | ||
|
|
||
| # Re-seed, then build the full input of shape (_BATCH, cp_world_size * _CHUNK_LEN, _HIDDEN_SIZE). | ||
| torch.manual_seed(0) | ||
| x_full = torch.randn(_BATCH, cp_world_size * _CHUNK_LEN, _HIDDEN_SIZE, device=device) | ||
| # NOTE(review): presumably returns this rank's slice of x_full along the sequence axis — confirm | ||
| # against prepare_context_parallel_input. | ||
| x_local = prepare_context_parallel_input((x_full,))[0] | ||
|
|
||
| # Kernel list passed to enable_kernels(): the mamba2_ssm kernel only when --use-mamba2-ssm was given. | ||
| kernels = [Kernel.mamba2_ssm] if args.use_mamba2_ssm else [] | ||
| cp_group = ProcessGroupManager.get_context_parallel_group() | ||
|
|
||
| # ---- forward ---- | ||
| # Run the module on the local input with the selected kernels enabled. | ||
| with enable_kernels(kernels): | ||
| out_local = mamba2(x_local.detach()) | ||
|
|
||
| # Collect every rank's output and concatenate along dim 1 (the cp_world_size * _CHUNK_LEN axis of x_full). | ||
| # Assumes the gathered rank order matches the order of the shards in x_full — TODO confirm. | ||
| parts = [torch.zeros_like(out_local) for _ in range(cp_world_size)] | ||
| torch.distributed.all_gather(parts, out_local.detach().contiguous(), group=cp_group) | ||
| out_cp_full = torch.cat(parts, dim=1) | ||
|
|
||
| # Reference: the same module on the unsharded x_full under set_dummy_context_parallel_world_size(1). | ||
| # NOTE(review): indentation is lost in this view; out_ref is only assigned on rank 0, so the | ||
| # assert_close below presumably sits inside the rank-0 branch — confirm. | ||
| if rank == 0: | ||
| with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1): | ||
| out_ref = mamba2(x_full.detach()) | ||
|
|
||
| assert_close(out_cp_full, out_ref) | ||
|
|
||
| # ---- backward (correctness: compare input grads against non-CP reference) ---- | ||
| # Fresh leaf tensor so the gradient with respect to the local input is recorded. | ||
| x_local_bwd = x_local.detach().requires_grad_(True) | ||
| with enable_kernels(kernels): | ||
| out_local_bwd = mamba2(x_local_bwd) | ||
| # Scalar loss: the sum of this rank's local output. | ||
| out_local_bwd.sum().backward() | ||
|
|
||
| assert x_local_bwd.grad is not None | ||
|
|
||
| # Collect every rank's input gradient and concatenate along dim 1, mirroring the forward check. | ||
| grad_parts = [torch.zeros_like(x_local_bwd.grad) for _ in range(cp_world_size)] | ||
| torch.distributed.all_gather(grad_parts, x_local_bwd.grad.contiguous(), group=cp_group) | ||
| grad_cp_full = torch.cat(grad_parts, dim=1) | ||
|
|
||
| # Reference gradient: backward through the module on the unsharded x_full with the dummy CP world size of 1. | ||
| # NOTE(review): indentation is lost in this view; x_full_bwd is only assigned on rank 0, so the | ||
| # assert_close below presumably sits inside the rank-0 branch — confirm. | ||
| if rank == 0: | ||
| x_full_bwd = x_full.detach().requires_grad_(True) | ||
| with enable_kernels(kernels), ProcessGroupManager.set_dummy_context_parallel_world_size(1): | ||
| out_ref_bwd = mamba2(x_full_bwd) | ||
| out_ref_bwd.sum().backward() | ||
|
|
||
| assert_close(grad_cp_full, x_full_bwd.grad) | ||
|
|
||
| # Tear down the process groups once the checks are done. | ||
| ProcessGroupManager.destroy_process_groups() |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,38 @@ | ||||||
| # ************************************************** | ||||||
| # Copyright (c) 2026, Mayank Mishra | ||||||
| # ************************************************** | ||||||
|
|
||||||
| import subprocess | ||||||
|
|
||||||
| import pytest | ||||||
| import torch | ||||||
|
|
||||||
| from lm_engine.utils import is_mamba_2_ssm_available | ||||||
|
|
||||||
| from ..utils import skip_test_if_device_unavailable, slow_test | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("use_mamba2_ssm", [True]) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The test currently only parametrizes `use_mamba2_ssm` with `[True]`.
Suggested change
|
||||||
| @slow_test | ||||||
| def test_mamba2_cp(use_mamba2_ssm: bool) -> None: | ||||||
| """Run the tests.context_parallel.mamba2_cp module under torchrun, one process per visible GPU. | ||||||
| | ||||||
| Skips when CUDA is unavailable, when mamba_ssm is requested but not installed, or when fewer | ||||||
| than 2 GPUs are visible. subprocess.run(check=True) raises CalledProcessError if the launched | ||||||
| command exits with a non-zero status. | ||||||
| """ | ||||||
| skip_test_if_device_unavailable(torch.device("cuda")) | ||||||
|
|
||||||
| if use_mamba2_ssm and not is_mamba_2_ssm_available(): | ||||||
| pytest.skip("mamba_ssm unavailable") | ||||||
|
|
||||||
| gpus_per_node = torch.cuda.device_count() | ||||||
| if gpus_per_node < 2: | ||||||
| pytest.skip("context parallel requires at least 2 GPUs") | ||||||
|
|
||||||
| # Argument list (no shell): torchrun starts gpus_per_node processes running the module given by -m. | ||||||
| command = [ | ||||||
| "torchrun", | ||||||
| "--nproc_per_node", | ||||||
| str(gpus_per_node), | ||||||
| "-m", | ||||||
| "tests.context_parallel.mamba2_cp", | ||||||
| ] | ||||||
|
|
||||||
| # Forward the kernel choice to the launched script as its --use-mamba2-ssm flag. | ||||||
| if use_mamba2_ssm: | ||||||
| command.append("--use-mamba2-ssm") | ||||||
|
|
||||||
| subprocess.run(command, check=True) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding `scan_output_zero.sum() * 0` to `ssm_state_zero` forces PyTorch's autograd to execute the backward pass of `mamba_chunk_scan_combined` for `scan_output_zero`. Since `scan_output_zero` is not used elsewhere and its gradient contribution is zero, this redundant backward pass doubles the computation time of the scan kernel during training. You should remove this addition to avoid the performance bottleneck.