[CONTEXT PARALLEL] add CP for mamba2 #482
Code Review
This pull request introduces context parallel (CP) support for the Mamba2 sequence mixer block, implementing a serial prefix scan over CP ranks to compute correct initial SSM states in both the PyTorch-native and CUDA-based forward paths. It also adds corresponding integration tests. The review feedback highlights two key improvements: removing a redundant addition in the CUDA path that unnecessarily triggers a backward pass on unused scan outputs, and expanding the test parametrization to cover the PyTorch-native path in addition to the Triton kernel path.
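The serial prefix scan mentioned in the summary can be illustrated with a scalar stand-in for the SSM recurrence. The sketch below is not the PR's code: it assumes a toy recurrence `h_t = a_t * h_{t-1} + x_t`, and the names `scan_chunk` and `cp_initial_states` are hypothetical. It relies only on the linearity of the recurrence: a chunk's final state equals its total decay times the incoming state, plus the final state it would reach from a zero initial state.

```python
def scan_chunk(decays, inputs, h0=0.0):
    """Run h_t = a_t * h_{t-1} + x_t over one chunk.

    Returns the per-step states and the final state.
    """
    h = h0
    outs = []
    for a, x in zip(decays, inputs):
        h = a * h + x
        outs.append(h)
    return outs, h


def cp_initial_states(chunk_decays, chunk_inputs):
    """Serially walk the ranks and return the initial state each rank needs.

    Each rank contributes two local quantities computed from a zero
    initial state: its final state and the product of its decays.
    """
    init_states = []
    h = 0.0  # rank 0 starts from a zero state
    for decays, inputs in zip(chunk_decays, chunk_inputs):
        init_states.append(h)
        _, local_final = scan_chunk(decays, inputs, 0.0)
        total_decay = 1.0
        for a in decays:
            total_decay *= a
        # Linearity: final(h) = total_decay * h + final(0)
        h = total_decay * h + local_final
    return init_states


if __name__ == "__main__":
    decays = [[0.5, 0.25], [0.5, 0.5], [0.25, 0.5]]  # one list per CP rank
    inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    print(cp_initial_states(decays, inputs))
```

In a real context-parallel setup the loop over ranks would be replaced by communication between ranks, and the states would be tensors rather than floats; the sketch only shows why each rank's zero-initialized local results are enough to recover the correct initial states.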
    dt_softplus=True,
    **dt_limit_kwargs,
)
ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0
Adding scan_output_zero.sum() * 0 to ssm_state_zero forces PyTorch's autograd to execute the backward pass of mamba_chunk_scan_combined for scan_output_zero. Since scan_output_zero is not used elsewhere and its gradient contribution is zero, this redundant backward pass doubles the computation time of the scan kernel during training. You should remove this addition to avoid the performance bottleneck.
Suggested change:
- ssm_state_zero = ssm_state_zero + scan_output_zero.sum().to(ssm_state_zero.dtype) * 0
+ # ssm_state_zero is already tracked by autograd; no need to add scan_output_zero
from ..utils import skip_test_if_device_unavailable, slow_test


@pytest.mark.parametrize("use_mamba2_ssm", [True])
The test currently only parametrizes use_mamba2_ssm with [True], which means only the Triton kernel path is tested. Since context parallel support was also added to the PyTorch-native path (_torch_forward), you should include False in the parameter list to ensure both paths are covered by the test suite.
Suggested change:
- @pytest.mark.parametrize("use_mamba2_ssm", [True])
+ @pytest.mark.parametrize("use_mamba2_ssm", [True, False])
No description provided.