Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lm_engine/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:

if fsdp_algorithm == 3:
backend = get_simple_fsdp_compile_backend(
fsdp_reshard_after_forward=zero3, auto_bucketing=True, backend=backend
fsdp_reshard_after_forward=zero3, bucketing_mode="auto", backend=backend
)

fullgraph = True
Expand Down
2 changes: 1 addition & 1 deletion lm_engine/hf_models/modeling_utils/activations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.relu(x)

y = x.to(torch.float32)
y *= y
y = y**2
y = y.type_as(x)

return y
Expand Down
15 changes: 14 additions & 1 deletion lm_engine/parallel/simple_fsdp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,18 @@
# Copyright (c) 2026, Mayank Mishra
# **************************************************

from .compile import get_simple_fsdp_compile_backend
from .compile import (
annotate_module_fqns,
async_tensor_parallel_pass,
autobucketing_reordering_pass,
create_extra_fsdp_pg,
get_extra_fsdp_pg_name,
get_simple_fsdp_compile_backend,
normalize_view_ops_as_reshape,
reassign_to_pg_pass,
remove_detach_pass,
remove_identity_slice_pass,
remove_identity_view_pass,
transformer_block_bucketing_reordering_pass,
)
from .fsdp import MixedPrecisionPolicy, data_parallel
177 changes: 146 additions & 31 deletions lm_engine/parallel/simple_fsdp/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,77 @@

# Adapted from torchtitan's graph_trainer JIT compile pipeline:
# https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/graph_trainer
# Subset: auto-bucketing + reshard-after-forward joint pass for SimpleFSDP.
# Original copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Any, Callable
import warnings
from collections.abc import Callable
from functools import partial
from typing import Any

import torch
import torch._functorch.config as functorch_config
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.config import aten_distributed_optimizations
from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor, is_wait_tensor
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing,
schedule_overlap_bucketing_from_inductor_configs,
)
from torch._inductor.fx_passes.micro_pipeline_tp import micro_pipeline_tp_pass
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
from torch._inductor.fx_passes.overlap_scheduling import get_group_name, schedule_overlap_bucketing
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.utils.checkpoint import CheckpointPolicy


def async_tensor_parallel_pass(gm: torch.fx.GraphModule, example_inputs: tuple) -> torch.fx.GraphModule:
    """Pipeline TP collectives with matmuls via symmetric memory.

    Enables symmetric memory for every process group referenced by an
    all-gather or reduce-scatter node in ``gm`` and then runs
    ``micro_pipeline_tp_pass``, which fuses all-gather + matmul into
    ``symm_mem.fused_all_gather_matmul`` and matmul + reduce-scatter into
    ``symm_mem.fused_matmul_reduce_scatter``.

    Args:
        gm: Graph module whose graph is rewritten in place.
        example_inputs: Not read by this pass; accepted so the function can
            be called with the usual ``(gm, example_inputs)`` pass arguments.

    Returns:
        The same ``gm`` object, after ``graph.lint()`` and ``recompile()``.
    """
    c10d = torch.ops._c10d_functional
    collective_targets = {
        c10d.all_gather_into_tensor.default,
        c10d.reduce_scatter_tensor.default,
    }

    # dict.fromkeys de-duplicates while keeping first-seen order, so each
    # process group is registered exactly once.
    group_names = dict.fromkeys(
        get_group_name(node) for node in gm.graph.nodes if node.target in collective_targets
    )

    # FutureWarning is silenced only around the registration calls
    # (presumably a deprecation notice from enable_symm_mem_for_group -- TODO confirm).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        for group_name in group_names:
            enable_symm_mem_for_group(group_name)

    micro_pipeline_tp_pass(gm.graph)
    gm.graph.lint()
    gm.recompile()
    return gm


def reassign_to_pg_pass(
gm: torch.fx.GraphModule,
example_inputs: Any = None,
*,
source_pg_name: str,
target_pg_name: str,
) -> torch.fx.GraphModule:
"""Rewrite all-gather nodes from ``source_pg_name`` to ``target_pg_name``.

Must be applied **before** bucketing passes so bucketed AGs inherit the
new PG.
"""
count = 0
for node in gm.graph.nodes:
if is_all_gather_into_tensor(node) and node.args[2] == source_pg_name:
node.args = (node.args[0], node.args[1], target_pg_name)
count += 1
gm.recompile()
return gm


def _is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
"""True if `node` is the wait_tensor of an FSDP all_gather that can be prefetched.
"""True if ``node`` is the wait_tensor of an FSDP all_gather that can be prefetched.

Mirrors torchtitan's heuristic: walk back from the wait through chains of
single-input ops to a graph placeholder. If reachable, the all_gather is
Expand All @@ -40,16 +92,18 @@ def _is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
def annotate_fsdp_all_gather(gm: torch.fx.GraphModule, reshard_after_forward: bool) -> torch.fx.GraphModule:
"""Tag SimpleFSDP all_gather nodes for reshard-after-forward behavior.

When `reshard_after_forward=True`, all_gathers and their immediate consumers
(wait + post-wait slice + dtype convert) are marked MUST_RECOMPUTE, so the
backward re-runs them rather than holding gathered params live through fwd.
When ``reshard_after_forward=True``, all_gathers and their immediate
consumers (wait + post-wait slice + dtype convert) are marked
MUST_RECOMPUTE so the backward re-runs them rather than holding gathered
params live through fwd.

When False, they are marked MUST_SAVE (params stay all-gathered).

`ac_graph_id=100000` ensures the recompute decision is not influenced by
neighboring AC regions (a partitioner workaround).
``ac_graph_id=100000`` prevents the recompute decision from being
influenced by neighbouring AC regions (partitioner workaround).

Run as a `joint_custom_pass` so AC tags survive the joint -> fwd/bwd partition.
Run as a ``joint_custom_pass`` so AC tags survive the joint->fwd/bwd
partition.
"""
graph = gm.graph

Expand Down Expand Up @@ -89,32 +143,71 @@ def fsdp_reshard_after_fwd_pass(
def autobucketing_reordering_pass(gm: torch.fx.GraphModule, example_inputs: Any | None = None) -> torch.fx.GraphModule:
    """Bucket and reorder collectives for compute/comm overlap.

    Runs ``schedule_overlap_bucketing`` on ``gm`` with
    ``collective_bucketing=True``, which fuses small all_gathers /
    reduce_scatters into larger ones and reorders them around compute.
    Used as the fw/bw compiler in the aot_eager path.

    Args:
        gm: Graph module to schedule.
        example_inputs: Not read by this pass.

    Returns:
        ``gm`` itself, recompiled after scheduling.
    """
    # The scheduler's return value is deliberately ignored; the caller gets
    # the input module back after it has been recompiled.
    schedule_overlap_bucketing(
        gm,
        collective_bucketing=True,
    )
    gm.recompile()
    return gm


def transformer_block_bucketing_reordering_pass(
    gm: torch.fx.GraphModule,
    example_inputs: Any | None = None,
    *,
    fsdp_manual_buckets: list[list[str] | str],
) -> torch.fx.GraphModule:
    """Manual aten-level bucketing and reordering per transformer block.

    Delegates to ``manual_overlap_bucketing`` (pytorch/pytorch#165487),
    which groups collectives by module FQN bucket and emits prefetch edges.
    The model must have been annotated with ``annotate_module_fqns`` before
    tracing so that nodes carry ``module_fqn`` metadata.

    Args:
        gm: Graph module to bucket.
        example_inputs: Not read by this pass.
        fsdp_manual_buckets: Bucket plan forwarded as
            ``module_bucket_plans``; each entry is a module FQN or a list
            of FQNs.

    Returns:
        ``gm`` itself, recompiled after bucketing.
    """
    # Overlap dependencies are not inserted in this variant of the pass.
    manual_overlap_bucketing(
        gm,
        module_bucket_plans=fsdp_manual_buckets,
        insert_overlap_deps=False,
    )
    gm.recompile()
    return gm


def get_simple_fsdp_compile_backend(
*, fsdp_reshard_after_forward: bool, auto_bucketing: bool, backend: str
*,
fsdp_reshard_after_forward: bool,
bucketing_mode: str = "none",
fsdp_manual_buckets: list[list[str] | str] | None = None,
async_tensor_parallel: bool = False,
backend: str,
) -> Callable:
"""Build a torch.compile backend that wires SimpleFSDP-aware passes.

- `fsdp_reshard_after_forward`: install `annotate_fsdp_all_gather` as the
joint_custom_pass so AG nodes get AC-recompute tags before the fwd/bwd
partition runs.
- `auto_bucketing`: enable inductor's distributed-collective bucketing
scheduler. Conflicts with `reorder_for_compute_comm_overlap`, so flips
that off; uses `post_grad_custom_post_pass` (inductor) or `aot_autograd`
with the bucketing pass as fw/bw compiler (aot_eager).
Args:
fsdp_reshard_after_forward: Install ``annotate_fsdp_all_gather`` as the
joint_custom_pass so AG nodes get AC-recompute tags before the
fwd/bwd partition runs.
bucketing_mode: Collective bucketing strategy.
- ``"none"``: no bucketing (default).
- ``"auto"``: global auto-bucketing via ``schedule_overlap_bucketing``.
- ``"transformer_block"``: per-layer manual bucketing via
``manual_overlap_bucketing``. Requires ``fsdp_manual_buckets``
and that the model was annotated with ``annotate_module_fqns``
before compilation.
fsdp_manual_buckets: Module FQN bucket plans for ``"transformer_block"``
mode. Each entry is a module FQN string or a list of FQN strings
that should be grouped into one bucket.
async_tensor_parallel: Apply ``async_tensor_parallel_pass`` to fuse TP
collectives with matmuls via symmetric memory.
backend: Underlying torch.compile backend (``"inductor"``,
``"aot_eager"``, ``"neuron"``, …).
"""
if bucketing_mode not in ("none", "auto", "transformer_block"):
raise ValueError(f"bucketing_mode must be 'none', 'auto', or 'transformer_block', got {bucketing_mode!r}")
if bucketing_mode == "transformer_block" and not fsdp_manual_buckets:
raise ValueError("fsdp_manual_buckets must be provided when bucketing_mode='transformer_block'")

torch._dynamo.config.capture_scalar_outputs = True
inner_backend = torch._dynamo.lookup_backend(backend)

if auto_bucketing:
if bucketing_mode == "auto":
if backend == "aot_eager":
aten_distributed_optimizations.insert_overlap_deps = False

Expand All @@ -123,23 +216,45 @@ def get_simple_fsdp_compile_backend(
bw_compiler=autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif backend in ["inductor", "neuron"]:
elif backend in ("inductor", "neuron"):

def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)
Comment on lines 221 to +222

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In Inductor, post_grad_custom_post_pass receives a GraphModule as its argument, not a Graph. Since GraphModule does not have an owning_module attribute, gm.owning_module will raise an AttributeError. You should pass gm directly to schedule_overlap_bucketing.

Suggested change
def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)
def _inductor_autobucketing_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
return schedule_overlap_bucketing(gm, collective_bucketing=True)


aten_distributed_optimizations.collective_bucketing = True
aten_distributed_optimizations.insert_overlap_deps = True

torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = _inductor_autobucketing_pass
else:
raise ValueError(f"Unsupported backend {backend} for auto_bucketing")
raise ValueError(f"Unsupported backend {backend!r} for bucketing_mode='auto'")
elif bucketing_mode == "transformer_block":
if backend == "aot_eager":
_tb_pass = partial(transformer_block_bucketing_reordering_pass, fsdp_manual_buckets=fsdp_manual_buckets)

inner_backend = aot_autograd_backend(
fw_compiler=_tb_pass, bw_compiler=_tb_pass, keep_inference_input_mutations=True
)
elif backend in ("inductor", "neuron"):

def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)
Comment on lines +241 to +244

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the autobucketing pass, gm here is a GraphModule. Accessing gm.owning_module will fail. Pass gm directly to manual_overlap_bucketing.

Suggested change
def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)
def _inductor_tb_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)


torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = _inductor_tb_pass
else:
raise ValueError(f"Unsupported backend {backend!r} for bucketing_mode='transformer_block'")

def _joint_ac_pass(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule:
return fsdp_reshard_after_fwd_pass(gm, example_inputs, reshard_after_forward=fsdp_reshard_after_forward)
gm = fsdp_reshard_after_fwd_pass(gm, example_inputs, reshard_after_forward=fsdp_reshard_after_forward)
if async_tensor_parallel:
gm = async_tensor_parallel_pass(gm, example_inputs)
return gm

def _backend_with_passes(*args, **kwargs):
with functorch_config.patch("joint_custom_pass", _joint_ac_pass):
Expand Down
3 changes: 1 addition & 2 deletions lm_engine/parallel/simple_fsdp/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def forward(self, x: DTensor) -> torch.Tensor:
if not _active_parametrization:
return x

output = self.replicate_compute(x)
return output
return self.replicate_compute(x)


def data_parallel(
Expand Down
Loading