
Add AutoSP example #999

Merged
tohtana merged 2 commits into deepspeedai:master from neeldani:autosp-example
Mar 30, 2026

Conversation

Contributor

@neeldani neeldani commented Feb 23, 2026

This PR adds an example of using AutoSP

Reference: deepspeedai/DeepSpeed#7860

@neeldani neeldani marked this pull request as ready for review March 12, 2026 06:32
Signed-off-by: Neel Dani <neeldani98@gmail.com>
Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
Contributor

@tohtana tohtana left a comment


Thank you for adding a great example! I left comments regarding ring attention.

tohtana added a commit to deepspeedai/DeepSpeed that referenced this pull request Mar 30, 2026
# AutoSP: Unlocking Long-Context LLM Training Via Compiler-Based Sequence Parallelism

## Overview

AutoSP is a compiler optimization pass that shards inputs along the
sequence dimension and enables Ulysses-style sequence parallelism while
preventing graph breaks during `torch.compile()`. All passes operate on
the Torch IR of the forward graph.

## API Design

### User-Facing Entry Point: `prepare_autosp_inputs()`

Users must explicitly call this function to prepare inputs for AutoSP
compilation:

```python
def prepare_autosp_inputs(
    input_id: torch.Tensor,
    label_id: torch.Tensor,
    position_id: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    seq_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
```

**Purpose**: Symbolize sequence dimension and annotate tensors for
identification.

**Operations**:
1. Mark sequence dimension as dynamic using
`torch._dynamo.decorators.mark_dynamic()`
2. Attach metadata tags for tensor identification for auto-sharding:
   - `input_id.tag = constants.INPUT_ID_KEY`
   - `label_id.tag = constants.LABEL_ID_KEY`
   - `position_id.tag = constants.POSITION_ID_KEY` (if provided)

**Rationale**: PyTorch's FX graph tracer requires explicit annotation of
data-dependent dimensions. Marking the sequence dimension as dynamic
prevents symbolic shape propagation from losing dimension information
through reshape/view operations.
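
To make the annotation step concrete, here is a minimal stand-in for what `prepare_autosp_inputs()` conceptually does. The `FakeTensor` class, `mark_dynamic` helper, and tag constants below are hypothetical stand-ins; the real implementation operates on `torch.Tensor` and calls `torch._dynamo.decorators.mark_dynamic()`.

```python
# Illustrative sketch only: models the two annotation steps (mark the
# sequence dim dynamic, attach identification tags) on a toy tensor type.

INPUT_ID_KEY = "input_id"        # assumed constant values, for illustration
LABEL_ID_KEY = "label_id"
POSITION_ID_KEY = "position_id"

class FakeTensor:
    """Tiny stand-in for torch.Tensor: a shape plus annotation slots."""
    def __init__(self, shape):
        self.shape = tuple(shape)
        self.dynamic_dims = set()   # dims marked dynamic for the tracer
        self.tag = None             # identification tag for auto-sharding

def mark_dynamic(t, dim):
    t.dynamic_dims.add(dim)

def prepare_autosp_inputs(input_id, label_id, position_id=None, seq_dim=1):
    # 1. Mark the sequence dimension as dynamic on every provided tensor.
    for t in (input_id, label_id, position_id):
        if t is not None:
            mark_dynamic(t, seq_dim)
    # 2. Attach metadata tags so the sharding passes can find these tensors.
    input_id.tag = INPUT_ID_KEY
    label_id.tag = LABEL_ID_KEY
    if position_id is not None:
        position_id.tag = POSITION_ID_KEY
    return input_id, label_id, position_id

ids, labels, pos = (FakeTensor((2, 4096)) for _ in range(3))
prepare_autosp_inputs(ids, labels, pos)
print(ids.tag, 1 in ids.dynamic_dims)   # → input_id True
```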

## Compilation Passes

### Pass 1: `pass_shard_seq_dim()`

**Objective**: Propagate sharded sequence dimension to all consumers.

**Algorithm**:
1. Extract symbolic sequence dimension from `input_id` shape metadata
2. Locate the symbolic dimension node in the FX graph
3. Create a floor-divide node: `seq_dim // world_size`
4. Perform a worklist-based graph traversal to find all direct and
indirect consumers of the input, label, and position-id nodes
5. Replace symbolic dimension references with the sharded dimension in
consumer nodes

**Rationale**: Reshapes and views that consume the sequence dimension as
an argument do not get updated during propagation of symbolic shapes.
This pass explicitly rewires the computation graph to use sharded
dimensions, enabling proper shape inference downstream.
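
The worklist traversal in step 4 can be sketched on a toy use-def graph. The node names and the `users` map below are invented for illustration; the real pass walks `fx.Node.users` edges and rewrites node arguments in place.

```python
# Sketch: collect every direct and indirect consumer of the input node,
# then map each one to the sharded sequence dimension (step 3's result).

WORLD_SIZE = 4
SEQ = 4096
SHARDED_SEQ = SEQ // WORLD_SIZE   # the floor-divide node from step 3

# users[n] = nodes consuming n's output (a stand-in for FX use-def edges)
users = {
    "input_id": ["embed"],
    "embed": ["view_1"],
    "view_1": ["matmul"],
    "matmul": [],
    "unrelated": [],
}

def transitive_consumers(roots):
    """Worklist traversal: gather all direct and indirect consumers."""
    seen, worklist = set(), list(roots)
    while worklist:
        node = worklist.pop()
        for user in users[node]:
            if user not in seen:
                seen.add(user)
                worklist.append(user)
    return seen

consumers = transitive_consumers(["input_id"])
# Step 5: each consumer's symbolic-dim reference is rewritten to SHARDED_SEQ.
dim_rewrites = {node: SHARDED_SEQ for node in consumers}
print(sorted(consumers), SHARDED_SEQ)   # → ['embed', 'matmul', 'view_1'] 1024
```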

### Pass 2: `pass_shard_input_ids()` / `pass_shard_label_ids()` / `pass_shard_position_ids()`

**Objective**: Insert slicing operations after input tensors.

**Implementation**: Call the `shard_tensor_node()` utility, which inserts
slice operations. Each rank retains only the portion of the tensor
corresponding to its sequence partition and drops the rest of the buffer.

**Note on `attention_mask`**: Not sharded because it applies to the full
sequence length, not the partitioned dimension.
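
The slice each rank keeps can be illustrated with nested Python lists standing in for tensors. The helper name mirrors `shard_tensor_node()` from the text, but this list-based version is purely illustrative.

```python
# Sketch of the per-rank slice that the sharding passes insert: rank r of
# world_size P keeps the r-th contiguous chunk of the sequence dimension.

def shard_along_seq(tensor, rank, world_size, seq_dim=1):
    """Keep only this rank's contiguous slice of the sequence dimension."""
    assert seq_dim == 1, "sketch only handles [batch, seq] layouts"
    seq_len = len(tensor[0])
    assert seq_len % world_size == 0, "sequence must divide evenly"
    shard = seq_len // world_size
    start = rank * shard
    return [row[start:start + shard] for row in tensor]

batch = [list(range(8))]                              # one sequence of length 8
print(shard_along_seq(batch, rank=1, world_size=4))   # → [[2, 3]]
```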

### Pass 3: `pass_insert_attention_all_to_all()`

**Objective**: Insert all-to-all collectives around attention
(Ulysses-style) to avoid graph breaks during compilation.

**Algorithm**:
1. Identify all SDPA (Scaled Dot-Product Attention) nodes in the graph
2. For each SDPA node with inputs Q, K, V, insert an all-to-all after
each of Q, K, V: scatter heads (dim=1), gather sequence (dim=2)
3. Insert an all-to-all after the attention output O: scatter sequence
(dim=2), gather heads (dim=1)

**Graph Rewrite Example**:
```
Q [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
K [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
V [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
                     |
                    SDPA
                     |
O [B, N/P, S, H] --A2A(scatter_seq,gather_heads)--> [B, N, S/P, H]
```

**Current support**: Only
`torch.nn.functional.scaled_dot_product_attention()` is recognized.
Composite attention patterns require additional pattern-matching logic.
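
The shape bookkeeping of the two all-to-all rewrites can be checked with a small shape-only sketch. No communication happens here; the functions below only track how each A2A transforms `[B, N, S/P, H]` metadata, matching the diagram above.

```python
# Shape-only sketch of the A2A rewrites around SDPA. The function names are
# illustrative; the real pass inserts collective ops into the FX graph.

def a2a_scatter_heads_gather_seq(shape, world_size):
    """Before SDPA: [B, N, S/P, H] -> [B, N/P, S, H]."""
    b, n, s_shard, h = shape
    assert n % world_size == 0, "world size must divide the head count"
    return (b, n // world_size, s_shard * world_size, h)

def a2a_scatter_seq_gather_heads(shape, world_size):
    """After SDPA: [B, N/P, S, H] -> [B, N, S/P, H]."""
    b, n_shard, s, h = shape
    assert s % world_size == 0, "world size must divide the sequence length"
    return (b, n_shard * world_size, s // world_size, h)

P = 4
qkv = (2, 32, 4096 // P, 128)              # [B, N, S/P, H] on each rank
attn_in = a2a_scatter_heads_gather_seq(qkv, P)
out = a2a_scatter_seq_gather_heads(attn_in, P)
print(attn_in, out == qkv)                 # → (2, 8, 4096, 128) True
```

Note that SDPA on each rank sees the full sequence but only `N/P` heads, which is exactly the Ulysses trade: head parallelism inside attention, sequence parallelism everywhere else.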

### Pass 4: `pass_propagate_shapes()`

**Objective**: Compute static shapes for all nodes using fake tensor
execution.

**Implementation**:
1. Create `ShapeEnv` for symbolic dimension tracking
2. Construct `FakeTensorMode` with the shape environment
3. Execute `FakeTensorProp.propagate()` to compute shape metadata
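
A toy analogue of what `FakeTensorProp.propagate()` accomplishes: walk the graph in topological order and compute output shapes without real data. The op names and graph encoding below are invented; the real pass uses torch's `FakeTensorMode` and `ShapeEnv`.

```python
# Sketch of metadata-only shape propagation over a tiny op list.

def infer_shape(op, in_shapes):
    """Per-op shape rule, evaluated on metadata only (no real tensors)."""
    if op == "matmul":                 # (..., M, K) @ (K, N) -> (..., M, N)
        a, b = in_shapes
        assert a[-1] == b[-2], "inner dimensions must match"
        return a[:-1] + (b[-1],)
    if op == "relu":                   # elementwise: shape unchanged
        return in_shapes[0]
    raise NotImplementedError(op)

graph = [  # (output name, op, input names), already topologically sorted
    ("h", "matmul", ("x", "w")),
    ("y", "relu", ("h",)),
]
shapes = {"x": (8, 1024, 512), "w": (512, 256)}
for name, op, inputs in graph:
    shapes[name] = infer_shape(op, tuple(shapes[i] for i in inputs))
print(shapes["y"])   # → (8, 1024, 256)
```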

### Pass 5: `pass_canonicalize()`

**Objective**: Finalize graph representation.

**Operations**:
1. `eliminate_dead_code()`: Remove unused operations
2. `lint()`: Validate graph structure
3. `recompile()`: Regenerate compiled representation
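
The `eliminate_dead_code()` step can be sketched as a reachability pass: any node whose result never reaches the graph output is dropped. The node list below is a made-up stand-in for an FX graph; torch.fx implements this on `fx.Graph` directly.

```python
# Sketch of dead-code elimination: keep only nodes transitively reachable
# from the graph output, walking use-def edges backwards.

def eliminate_dead_code(nodes, output):
    """nodes: list of (name, input names). Returns the live subset."""
    deps = {name: set(inputs) for name, inputs in nodes}
    live, worklist = set(), [output]
    while worklist:
        n = worklist.pop()
        if n in live:
            continue
        live.add(n)
        worklist.extend(deps.get(n, ()))   # graph inputs have no deps entry
    return [(n, i) for n, i in nodes if n in live]

nodes = [("a", ["x"]), ("dead", ["x"]), ("out", ["a"])]
print([n for n, _ in eliminate_dead_code(nodes, "out")])   # → ['a', 'out']
```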

## Execution Order

```
prepare_autosp_inputs()
    ↓
pass_shard_seq_dim
    ↓
pass_shard_input_ids
    ↓
pass_shard_label_ids
    ↓
pass_shard_position_ids
    ↓
pass_insert_attention_all_to_all
    ↓
pass_propagate_shapes
    ↓
pass_canonicalize
    ↓
pass_selective_activation_checkpointing

```
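
The ordering above amounts to a simple pass pipeline: each pass is a graph-to-graph function applied in sequence. The pass bodies below are no-op placeholders (the real passes are described in the sections above); only the composition pattern is the point.

```python
# Sketch of running the AutoSP passes in their required order.

def run_pipeline(graph, passes):
    for p in passes:
        graph = p(graph)       # each pass returns a rewritten graph
    return graph

trace = []
def make_pass(name):
    def _pass(graph):
        trace.append(name)     # record execution order for illustration
        return graph
    return _pass

passes = [make_pass(n) for n in (
    "pass_shard_seq_dim", "pass_shard_input_ids", "pass_shard_label_ids",
    "pass_shard_position_ids", "pass_insert_attention_all_to_all",
    "pass_propagate_shapes", "pass_canonicalize",
    "pass_selective_activation_checkpointing",
)]
run_pipeline({"nodes": []}, passes)
print(trace[0], "...", trace[-1])
```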

## Memory savings
AutoSP adds heuristics to torch.compile's partitioner, which splits the
joint graph into the forward and backward graphs. Matmul and related ops
are not checkpointed, since recomputing them in the backward pass is much
cheaper than recomputing the attention op; this reduces peak active
memory.

## Reducing gradients across ranks
AutoSP requires an all-reduce of the gradients across ranks. This is
invoked automatically by DeepSpeed's engine
[here](https://github.com/deepspeedai/DeepSpeed/blob/93524c8931799a7631a2321d7ef4afaff6b6e54b/deepspeed/runtime/engine.py#L2433)
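
The shape of that reduction can be sketched in pure Python: each rank holds a local gradient and the collective combines them so every rank ends up with the same value. This stand-in averages the gradients; whether the engine sums or averages depends on its configuration, and the real call goes through `torch.distributed`.

```python
# Illustrative stand-in for a gradient all-reduce across ranks.

def all_reduce_mean(per_rank_grads):
    """Average corresponding gradient entries; every rank gets the result."""
    world_size = len(per_rank_grads)
    summed = [sum(g[i] for g in per_rank_grads)
              for i in range(len(per_rank_grads[0]))]
    avg = [s / world_size for s in summed]
    return [list(avg) for _ in per_rank_grads]

grads = [[1.0, 2.0], [3.0, 4.0]]   # two ranks, two parameters each
print(all_reduce_mean(grads))      # → [[2.0, 3.0], [2.0, 3.0]]
```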

## Known Limitations

1. **Attention Pattern Matching**: Only
`torch.nn.functional.scaled_dot_product_attention()` is supported. Fused
attention implementations require pattern-specific handling.
2. **No Graph Breaks**: AutoSP fails if the graph contains breaks,
because use-def chains are lost and auto-sharding information cannot be
propagated across graph modules.


## Example 

DeepSpeedExample PR:
deepspeedai/DeepSpeedExamples#999

---------

Signed-off-by: Neel Dani <neeldani98@gmail.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
Signed-off-by: Neel Dani <neeldani98@gmail.com>
Contributor

@tohtana tohtana left a comment


Thank you, @neeldani!

@tohtana tohtana merged commit 45b4b71 into deepspeedai:master Mar 30, 2026
2 checks passed