[CONTEXT PARALLEL] add CP for mamba2 #482

Merged
mayank31398 merged 6 commits into main from mamba on Jun 26, 2026

@mayank31398

Collaborator

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces context parallel (CP) support for the Mamba2 sequence mixer block, implementing a serial prefix scan over CP ranks to compute correct initial SSM states in both the PyTorch-native and CUDA-based forward paths. It also adds corresponding integration tests. The review feedback highlights two key improvements: removing a redundant addition in the CUDA path that unnecessarily triggers a backward pass on unused scan outputs, and expanding the test parametrization to cover the PyTorch-native path in addition to the Triton kernel path.
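
To make the "serial prefix scan over CP ranks" idea concrete, here is a minimal sketch on a toy scalar recurrence `h[t] = a[t] * h[t-1] + bx[t]`. It is not the PR's implementation: the function names, the scalar state, and the in-process loop standing in for cross-rank communication are all assumptions made for illustration. Each simulated rank first scans its chunk from a zero state, the per-chunk summaries are then combined serially to get each rank's true initial state, and each chunk is rescanned from that state.

```python
# Toy illustration only: names are hypothetical and ranks are simulated
# with plain Python loops instead of real inter-device communication.

def local_scan(a, bx, h0):
    """Run h[t] = a[t] * h[t-1] + bx[t] over one chunk, starting from h0."""
    states = []
    h = h0
    for a_t, bx_t in zip(a, bx):
        h = a_t * h + bx_t
        states.append(h)
    return states, h


def chunk_decay(a):
    """Product of the decays in a chunk: how much of an incoming state survives."""
    prod = 1.0
    for a_t in a:
        prod *= a_t
    return prod


def context_parallel_scan(a, bx, num_ranks):
    n = len(a)
    assert n % num_ranks == 0, "sequence must split evenly across ranks"
    size = n // num_ranks
    chunks = [
        (a[r * size:(r + 1) * size], bx[r * size:(r + 1) * size])
        for r in range(num_ranks)
    ]

    # Pass 1 (independent per rank): scan from a zero state and keep
    # only the chunk summary (total decay, final state from zero).
    summaries = []
    for a_c, bx_c in chunks:
        _, final_from_zero = local_scan(a_c, bx_c, 0.0)
        summaries.append((chunk_decay(a_c), final_from_zero))

    # Serial prefix scan over ranks: rank r's initial state is the true
    # final state of rank r - 1, which by linearity of the recurrence is
    # decay * incoming_state + final_from_zero.
    initial_states = []
    carry = 0.0
    for decay, final_from_zero in summaries:
        initial_states.append(carry)
        carry = decay * carry + final_from_zero

    # Pass 2 (independent per rank): rescan with the correct initial state.
    out = []
    for (a_c, bx_c), h0 in zip(chunks, initial_states):
        states, _ = local_scan(a_c, bx_c, h0)
        out.extend(states)
    return out


a = [0.9, 0.5, 0.8, 0.7, 0.6, 0.95, 0.4, 0.3]
bx = [1.0, -2.0, 0.5, 3.0, 1.5, -1.0, 2.0, 0.25]
reference, _ = local_scan(a, bx, 0.0)          # single-device result
split = context_parallel_scan(a, bx, num_ranks=4)
```

The serial step touches only one small summary per rank, so its cost grows with the number of ranks rather than the sequence length. The two results agree up to floating-point rounding, since the carried state is accumulated in a different order.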

dt_softplus=True,
**dt_limit_kwargs,
)
ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0

high

Adding scan_output_zero.sum() * 0 to ssm_state_zero forces PyTorch's autograd to execute the backward pass of mamba_chunk_scan_combined for scan_output_zero. Since scan_output_zero is not used elsewhere and its gradient contribution is zero, this redundant backward pass doubles the computation time of the scan kernel during training. You should remove this addition to avoid the performance bottleneck.

Suggested change
- ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0
+ # ssm_state_zero is already tracked by autograd; no need to add scan_output_zero
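
The autograd behaviour behind this comment can be reproduced with a small PyTorch sketch. Everything here is a stand-in: `CountingDouble` is a hypothetical op that only counts its backward calls, `unused` plays the role of `scan_output_zero`, and `state` plays the role of `ssm_state_zero`. The sketch assumes the unused tensor comes from its own autograd node, which may not match the actual kernel call in the PR.

```python
import torch


class CountingDouble(torch.autograd.Function):
    """Hypothetical stand-in for an expensive op; counts backward invocations."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        CountingDouble.backward_calls += 1
        return grad_output * 2


def run(attach_unused_output):
    CountingDouble.backward_calls = 0
    x = torch.ones(3, requires_grad=True)
    w = torch.ones(3, requires_grad=True)

    unused = CountingDouble.apply(x)  # plays the role of scan_output_zero
    state = w * 3                     # plays the role of ssm_state_zero

    if attach_unused_output:
        # Numerically a no-op, but it links `unused` into the graph,
        # so autograd will run its backward with an all-zero gradient.
        state = state + unused.sum().to(state.dtype) * 0

    state.sum().backward()
    return CountingDouble.backward_calls, x.grad


calls_with, grad_with = run(attach_unused_output=True)
calls_without, grad_without = run(attach_unused_output=False)
```

With the `* 0` term, the op's backward runs once and `x.grad` is a tensor of zeros; without it, the backward never runs and `x.grad` stays `None`. If both tensors are outputs of the same autograd node, that node's backward runs once either way, so the real cost in the PR is worth profiling rather than assuming.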

from ..utils import skip_test_if_device_unavailable, slow_test


@pytest.mark.parametrize("use_mamba2_ssm", [True])

medium

The test currently only parametrizes use_mamba2_ssm with [True], which means only the Triton kernel path is tested. Since context parallel support was also added to the PyTorch-native path (_torch_forward), you should include False in the parameter list to ensure both paths are covered by the test suite.

Suggested change
- @pytest.mark.parametrize("use_mamba2_ssm", [True])
+ @pytest.mark.parametrize("use_mamba2_ssm", [True, False])

@mayank31398 merged commit df3f94a into main on Jun 26, 2026
2 checks passed
@mayank31398 deleted the mamba branch on June 26, 2026 01:00