tjruwase reviewed Mar 1, 2026
tohtana reviewed Mar 30, 2026
tohtana (Contributor) left a comment:
Thank you for adding a great example! I left comments regarding ring attention.
tohtana added a commit to deepspeedai/DeepSpeed that referenced this pull request Mar 30, 2026
# AutoSP: Unlocking Long-Context LLM Training Via Compiler-Based
Sequence Parallelism
## Overview
AutoSP is a compiler optimization pass that shards inputs along the
sequence dimension and enables Ulysses-style sequence parallelism while
preventing graph breaks during `torch.compile()`. All passes operate on
the Torch IR of the forward graph.
## API Design
### User-Facing Entry Point: `prepare_autosp_inputs()`
Users must explicitly call this function to prepare inputs for AutoSP
compilation:
```python
from typing import Optional, Tuple
import torch

def prepare_autosp_inputs(
    input_id: torch.Tensor,
    label_id: torch.Tensor,
    position_id: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    seq_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
```
**Purpose**: Symbolize sequence dimension and annotate tensors for
identification.
**Operations**:
1. Mark sequence dimension as dynamic using
`torch._dynamo.decorators.mark_dynamic()`
2. Attach metadata tags for tensor identification for auto-sharding:
- `input_id.tag = constants.INPUT_ID_KEY`
- `label_id.tag = constants.LABEL_ID_KEY`
- `position_id.tag = constants.POSITION_ID_KEY` (if provided)
**Rationale**: PyTorch's FX graph tracer requires explicit annotation of
data-dependent dimensions. Marking the sequence dimension as dynamic
prevents symbolic shape propagation from losing dimension information
through reshape/view operations.
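To make the two operations concrete, here is a torch-free sketch using a stand-in tensor class. The real implementation calls `torch._dynamo.decorators.mark_dynamic()` on actual `torch.Tensor`s; the `FakeTensor` class and `mark_dynamic` helper below are illustrative stand-ins, and the key constants mirror the `constants.*_KEY` tags described above:

```python
from dataclasses import dataclass, field
from typing import Optional, Set, Tuple

# Stand-ins for constants.INPUT_ID_KEY etc. (names assumed for illustration)
INPUT_ID_KEY, LABEL_ID_KEY, POSITION_ID_KEY = "input_id", "label_id", "position_id"

@dataclass
class FakeTensor:
    """Illustrative stand-in for torch.Tensor."""
    shape: Tuple[int, ...]
    tag: Optional[str] = None
    dynamic_dims: Set[int] = field(default_factory=set)

def mark_dynamic(t: FakeTensor, dim: int) -> None:
    # Mimics torch._dynamo.decorators.mark_dynamic(t, dim)
    t.dynamic_dims.add(dim)

def prepare_autosp_inputs(input_id, label_id, position_id=None,
                          attention_mask=None, seq_dim=1):
    for t, key in ((input_id, INPUT_ID_KEY),
                   (label_id, LABEL_ID_KEY),
                   (position_id, POSITION_ID_KEY)):
        if t is None:
            continue
        mark_dynamic(t, seq_dim)  # step 1: symbolize the sequence dim
        t.tag = key               # step 2: attach the identification tag
    return input_id, label_id, position_id, attention_mask

ids = FakeTensor((2, 4096))
labels = FakeTensor((2, 4096))
prepare_autosp_inputs(ids, labels)
print(ids.tag, ids.dynamic_dims)  # → input_id {1}
```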
## Compilation Passes
### Pass 1: `pass_shard_seq_dim()`
**Objective**: Propagate sharded sequence dimension to all consumers.
**Algorithm**:
1. Extract symbolic sequence dimension from `input_id` shape metadata
2. Locate the symbolic dimension node in the FX graph
3. Create a floor-divide node: `seq_dim // world_size`
4. Perform worklist-based graph traversal to find all direct and
indirect consumers of input node, label node and position id node.
5. Replace symbolic dimension references with sharded dimension in
consumer nodes
**Rationale**: Reshapes and views that consume the sequence dimension as
an argument do not get updated during propagation of symbolic shapes.
This pass explicitly rewires the computation graph to use sharded
dimensions, enabling proper shape inference downstream.
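Steps 4-5 can be sketched as a worklist traversal over a toy use-def graph. The node names and the dict-based graph representation are illustrative, not the actual AutoSP/FX code:

```python
def propagate_sharded_dim(users, args, roots, old_sym, new_sym):
    """Worklist traversal: rewrite old_sym -> new_sym in the args of
    every direct or indirect consumer of the root nodes."""
    seen, work = set(roots), list(roots)
    while work:
        node = work.pop()
        for user in users.get(node, ()):
            # step 5: replace the symbolic seq dim in the consumer's args
            args[user] = [new_sym if a == old_sym else a for a in args[user]]
            if user not in seen:
                seen.add(user)
                work.append(user)
    return args

# Toy graph: input_ids -> view -> matmul; "S" is the symbolic seq length
users = {"input_ids": ["view"], "view": ["matmul"]}
args = {"view": ["B", "S"], "matmul": ["S", "H"]}
propagate_sharded_dim(users, args, ["input_ids"], "S", "S//P")
print(args)  # → both consumers now reference the sharded dim "S//P"
```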
### Pass 2: `pass_shard_input_ids()` / `pass_shard_label_ids()` /
`pass_shard_position_ids()`
**Objective**: Insert slicing operations after input tensors.
**Implementation**: Call `shard_tensor_node()` utility which inserts
slice operations. Each rank retains only the portion of the tensor
corresponding to its sequence partition and drops the remaining buffer.
**Note on `attention_mask`**: Not sharded because it applies to the full
sequence length, not the partitioned dimension.
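The slice each rank keeps amounts to one contiguous chunk of the sequence. A minimal sketch of the index arithmetic behind the inserted slice ops (`shard_bounds` is a hypothetical helper, not the actual `shard_tensor_node()` utility):

```python
def shard_bounds(seq_len: int, world_size: int, rank: int):
    """Half-open slice [start, end) of the sequence kept by this rank."""
    assert seq_len % world_size == 0, "seq_len must divide world_size evenly"
    chunk = seq_len // world_size
    return rank * chunk, (rank + 1) * chunk

tokens = list(range(8))  # a toy sequence of length 8, world_size = 4
start, end = shard_bounds(len(tokens), world_size=4, rank=2)
print(tokens[start:end])  # → [4, 5]
```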
### Pass 3: `pass_insert_attention_all_to_all()`
**Objective**: Insert all-to-all collectives around attention
(Ulysses-style) to avoid graph breaks during compilation.
**Algorithm**:
1. Identify all SDPA (Scaled Dot-Product Attention) nodes in the graph
2. For each SDPA node, insert an A2A after each of its Q, K, V inputs
that scatters heads (dim=1) and gathers sequence (dim=2)
3. Insert an A2A after the attention output O that scatters sequence
(dim=2) and gathers heads (dim=1)
**Graph Rewrite Example**:
```
Q [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
K [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
V [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
|
SDPA
|
O [B, N/P, S, H] --A2A(scatter_seq,gather_heads)--> [B, N, S/P, H]
```
**Current support**: Only
`torch.nn.functional.scaled_dot_product_attention()` is supported.
Composite attention patterns require additional pattern-matching logic.
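The shape bookkeeping in the rewrite above can be checked with a small sketch (B=batch, N=heads, S=sequence, H=head dim, P=world size); the two A2A layouts are exact inverses of each other. The function names are illustrative:

```python
def a2a_scatter_heads_gather_seq(shape, p):
    """[B, N, S/P, H] -> [B, N/P, S, H]: scatter dim 1, gather dim 2."""
    b, n, s_local, h = shape
    assert n % p == 0, "head count must divide the world size evenly"
    return (b, n // p, s_local * p, h)

def a2a_scatter_seq_gather_heads(shape, p):
    """[B, N/P, S, H] -> [B, N, S/P, H]: scatter dim 2, gather dim 1."""
    b, n_local, s, h = shape
    assert s % p == 0, "sequence length must divide the world size evenly"
    return (b, n_local * p, s // p, h)

qkv = (2, 16, 512, 64)  # per-rank [B, N, S/P, H] with P = 4
inner = a2a_scatter_heads_gather_seq(qkv, 4)
print(inner)                                   # → (2, 4, 2048, 64)
print(a2a_scatter_seq_gather_heads(inner, 4))  # → (2, 16, 512, 64)
```

Inside the SDPA region each rank holds all of the sequence for a subset of heads, which is why attention computes correctly without ring-style communication.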
### Pass 4: `pass_propagate_shapes()`
**Objective**: Compute static shapes for all nodes using fake tensor
execution.
**Implementation**:
1. Create `ShapeEnv` for symbolic dimension tracking
2. Construct `FakeTensorMode` with the shape environment
3. Execute `FakeTensorProp.propagate()` to compute shape metadata
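Conceptually, `FakeTensorProp` walks the graph in topological order and derives each node's output shape from its inputs' shapes without allocating real tensors. A torch-free analogy of that propagation (the graph encoding and shape rules are illustrative):

```python
def propagate_shapes(topo_order, node_inputs, shape_rules, input_shapes):
    """Walk nodes in topological order, computing each node's output
    shape from its inputs' shapes (what FakeTensorProp does with
    fake tensors and a ShapeEnv)."""
    shapes = dict(input_shapes)
    for node in topo_order:
        if node in shapes:
            continue  # graph inputs already carry shapes
        rule = shape_rules[node]
        shapes[node] = rule(*(shapes[i] for i in node_inputs[node]))
    return shapes

# Toy graph: x -> view -> mm(w)
shapes = propagate_shapes(
    topo_order=["x", "w", "view", "mm"],
    node_inputs={"view": ["x"], "mm": ["view", "w"]},
    shape_rules={
        "view": lambda s: (s[0] * s[1], s[2]),  # flatten batch and seq
        "mm": lambda a, b: (a[0], b[1]),        # matmul shape rule
    },
    input_shapes={"x": (2, 512, 64), "w": (64, 128)},
)
print(shapes["mm"])  # → (1024, 128)
```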
### Pass 5: `pass_canonicalize()`
**Objective**: Finalize graph representation.
**Operations**:
1. `eliminate_dead_code()`: Remove unused operations
2. `lint()`: Validate graph structure
3. `recompile()`: Regenerate compiled representation
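Step 1 can be pictured as a reverse liveness sweep: only nodes transitively reachable from the graph outputs survive. This is an illustrative sketch, not FX's actual `eliminate_dead_code()` implementation:

```python
def eliminate_dead_code(nodes, node_inputs, outputs):
    """Keep only nodes transitively reachable from the graph outputs."""
    live, work = set(outputs), list(outputs)
    while work:
        for dep in node_inputs.get(work.pop(), ()):
            if dep not in live:
                live.add(dep)
                work.append(dep)
    return [n for n in nodes if n in live]  # preserves original node order

nodes = ["a", "b", "dead", "out"]
node_inputs = {"b": ["a"], "dead": ["a"], "out": ["b"]}
print(eliminate_dead_code(nodes, node_inputs, ["out"]))  # → ['a', 'b', 'out']
```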
## Execution Order
```
prepare_autosp_inputs()
↓
pass_shard_seq_dim
↓
pass_shard_input_ids
↓
pass_shard_label_ids
↓
pass_shard_position_ids
↓
pass_insert_attention_all_to_all
↓
pass_propagate_shapes
↓
pass_canonicalize
↓
pass_selective_activation_checkpointing
```
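The ordering above can be driven by a simple pipeline that threads the graph module through each pass in turn. The driver below is a hypothetical sketch (the pass names are those listed; the real implementation operates on an FX `GraphModule`):

```python
def run_autosp_pipeline(graph, passes):
    """Apply each pass in order; a pass may mutate the graph in place
    (returning None) or return a rewritten graph."""
    for p in passes:
        result = p(graph)
        graph = result if result is not None else graph
    return graph

# Toy demonstration with string "graphs" standing in for GraphModules
def tag_pass(name):
    def p(graph):
        return graph + "->" + name  # record that this pass ran
    return p

g = run_autosp_pipeline("g", [tag_pass("pass_shard_seq_dim"),
                              tag_pass("pass_canonicalize")])
print(g)  # → g->pass_shard_seq_dim->pass_canonicalize
```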
## Memory savings
AutoSP adds heuristics to torch.compile's partitioner, which splits the
joint graph into the forward and backward graphs. Matmul and related ops
are not checkpointed, since recomputing them in the backward pass is
much cheaper than recomputing the attention op; this reduces the peak
active memory.
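A sketch of such a partitioner heuristic: ops that are cheap to recompute (matmuls and friends) are recomputed in the backward pass, while expensive ops like attention have their outputs saved at the forward/backward split. The op names and the exact cheap-op set below are illustrative:

```python
# Assumed set of ops worth recomputing rather than saving (illustrative)
CHEAP_TO_RECOMPUTE = {"mm", "addmm", "bmm", "add", "mul", "view"}

def saved_for_backward(forward_ops):
    """Return the ops whose outputs are checkpointed (saved) across the
    forward/backward split; everything else is recomputed."""
    return [op for op in forward_ops if op not in CHEAP_TO_RECOMPUTE]

ops = ["view", "mm", "scaled_dot_product_attention", "add", "mm"]
print(saved_for_backward(ops))  # → ['scaled_dot_product_attention']
```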
## Reducing gradients across ranks
AutoSP requires an all-reduce of the gradients across ranks. This is
invoked automatically by DeepSpeed's engine
[here](https://github.com/deepspeedai/DeepSpeed/blob/93524c8931799a7631a2321d7ef4afaff6b6e54b/deepspeed/runtime/engine.py#L2433)
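The effect of that collective can be emulated on plain lists. In practice this is `torch.distributed.all_reduce` over the process group; the single-process emulation below only illustrates the semantics:

```python
def all_reduce_sum(per_rank_grads):
    """Emulate a SUM all-reduce: every rank ends up with the elementwise
    sum of all ranks' gradients."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(summed) for _ in per_rank_grads]  # each rank gets a copy

grads = [[1, 2], [3, 4]]     # gradients from 2 ranks
print(all_reduce_sum(grads)) # → [[4, 6], [4, 6]]
```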
## Known Limitations
1. **Attention Pattern Matching**: Only
`torch.nn.functional.scaled_dot_product_attention()` is supported. Fused
attention implementations require pattern-specific handling.
2. **No Graph Break Requirement**: AutoSP will fail if there are graph
breaks because use-def chains are lost and it becomes tricky to
propagate auto-sharding information across graph modules.
## Example
DeepSpeedExample PR:
deepspeedai/DeepSpeedExamples#999
---------
Signed-off-by: Neel Dani <neeldani98@gmail.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
Signed-off-by: Neel Dani <neeldani98@gmail.com>
tohtana approved these changes Mar 30, 2026
This PR adds an example of using AutoSP
Reference: deepspeedai/DeepSpeed#7860