-
Notifications
You must be signed in to change notification settings - Fork 30
[FSDP] improve compiler passes #424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9f4d96a
d61a1ae
440cdca
9bcca80
36b3a85
c842e8c
9264a9b
b373bbe
019f3d3
95187b8
69a35a2
d066c7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,25 +4,77 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| # Adapted from torchtitan's graph_trainer JIT compile pipeline: | ||||||||||||||||||
| # https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/graph_trainer | ||||||||||||||||||
| # Subset: auto-bucketing + reshard-after-forward joint pass for SimpleFSDP. | ||||||||||||||||||
| # Original copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||
|
|
||||||||||||||||||
| from typing import Any, Callable | ||||||||||||||||||
| import warnings | ||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||
| from functools import partial | ||||||||||||||||||
| from typing import Any | ||||||||||||||||||
|
|
||||||||||||||||||
| import torch | ||||||||||||||||||
| import torch._functorch.config as functorch_config | ||||||||||||||||||
| from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend | ||||||||||||||||||
| from torch._inductor.config import aten_distributed_optimizations | ||||||||||||||||||
| from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor, is_wait_tensor | ||||||||||||||||||
| from torch._inductor.fx_passes.overlap_scheduling import ( | ||||||||||||||||||
| schedule_overlap_bucketing, | ||||||||||||||||||
| schedule_overlap_bucketing_from_inductor_configs, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from torch._inductor.fx_passes.micro_pipeline_tp import micro_pipeline_tp_pass | ||||||||||||||||||
| from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing | ||||||||||||||||||
| from torch._inductor.fx_passes.overlap_scheduling import get_group_name, schedule_overlap_bucketing | ||||||||||||||||||
| from torch.distributed._symmetric_memory import enable_symm_mem_for_group | ||||||||||||||||||
| from torch.utils.checkpoint import CheckpointPolicy | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def async_tensor_parallel_pass(gm: torch.fx.GraphModule, example_inputs: tuple) -> torch.fx.GraphModule: | ||||||||||||||||||
| """Pipeline TP collectives with matmuls via symmetric memory. | ||||||||||||||||||
|
|
||||||||||||||||||
| Fuses all-gather + matmul into ``symm_mem.fused_all_gather_matmul`` and | ||||||||||||||||||
| matmul + reduce-scatter into ``symm_mem.fused_matmul_reduce_scatter``. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| c10d = torch.ops._c10d_functional | ||||||||||||||||||
| collective_targets = { | ||||||||||||||||||
| c10d.all_gather_into_tensor.default, | ||||||||||||||||||
| c10d.reduce_scatter_tensor.default, | ||||||||||||||||||
| } | ||||||||||||||||||
| registered: set[str] = set() | ||||||||||||||||||
| for node in gm.graph.nodes: | ||||||||||||||||||
| if node.target not in collective_targets: | ||||||||||||||||||
| continue | ||||||||||||||||||
| pg = get_group_name(node) | ||||||||||||||||||
| if pg not in registered: | ||||||||||||||||||
| registered.add(pg) | ||||||||||||||||||
| with warnings.catch_warnings(): | ||||||||||||||||||
| warnings.simplefilter("ignore", FutureWarning) | ||||||||||||||||||
| enable_symm_mem_for_group(pg) | ||||||||||||||||||
|
|
||||||||||||||||||
| micro_pipeline_tp_pass(gm.graph) | ||||||||||||||||||
| gm.graph.lint() | ||||||||||||||||||
| gm.recompile() | ||||||||||||||||||
| return gm | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def reassign_to_pg_pass( | ||||||||||||||||||
| gm: torch.fx.GraphModule, | ||||||||||||||||||
| example_inputs: Any = None, | ||||||||||||||||||
| *, | ||||||||||||||||||
| source_pg_name: str, | ||||||||||||||||||
| target_pg_name: str, | ||||||||||||||||||
| ) -> torch.fx.GraphModule: | ||||||||||||||||||
| """Rewrite all-gather nodes from ``source_pg_name`` to ``target_pg_name``. | ||||||||||||||||||
|
|
||||||||||||||||||
| Must be applied **before** bucketing passes so bucketed AGs inherit the | ||||||||||||||||||
| new PG. | ||||||||||||||||||
| """ | ||||||||||||||||||
| count = 0 | ||||||||||||||||||
| for node in gm.graph.nodes: | ||||||||||||||||||
| if is_all_gather_into_tensor(node) and node.args[2] == source_pg_name: | ||||||||||||||||||
| node.args = (node.args[0], node.args[1], target_pg_name) | ||||||||||||||||||
| count += 1 | ||||||||||||||||||
| gm.recompile() | ||||||||||||||||||
| return gm | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool: | ||||||||||||||||||
| """True if `node` is the wait_tensor of an FSDP all_gather that can be prefetched. | ||||||||||||||||||
| """True if ``node`` is the wait_tensor of an FSDP all_gather that can be prefetched. | ||||||||||||||||||
|
|
||||||||||||||||||
| Mirrors torchtitan's heuristic: walk back from the wait through chains of | ||||||||||||||||||
| single-input ops to a graph placeholder. If reachable, the all_gather is | ||||||||||||||||||
|
|
@@ -40,16 +92,18 @@ def _is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool: | |||||||||||||||||
| def annotate_fsdp_all_gather(gm: torch.fx.GraphModule, reshard_after_forward: bool) -> torch.fx.GraphModule: | ||||||||||||||||||
| """Tag SimpleFSDP all_gather nodes for reshard-after-forward behavior. | ||||||||||||||||||
|
|
||||||||||||||||||
| When `reshard_after_forward=True`, all_gathers and their immediate consumers | ||||||||||||||||||
| (wait + post-wait slice + dtype convert) are marked MUST_RECOMPUTE, so the | ||||||||||||||||||
| backward re-runs them rather than holding gathered params live through fwd. | ||||||||||||||||||
| When ``reshard_after_forward=True``, all_gathers and their immediate | ||||||||||||||||||
| consumers (wait + post-wait slice + dtype convert) are marked | ||||||||||||||||||
| MUST_RECOMPUTE so the backward re-runs them rather than holding gathered | ||||||||||||||||||
| params live through fwd. | ||||||||||||||||||
|
|
||||||||||||||||||
| When False, they are marked MUST_SAVE (params stay all-gathered). | ||||||||||||||||||
|
|
||||||||||||||||||
| `ac_graph_id=100000` ensures the recompute decision is not influenced by | ||||||||||||||||||
| neighboring AC regions (a partitioner workaround). | ||||||||||||||||||
| ``ac_graph_id=100000`` prevents the recompute decision from being | ||||||||||||||||||
| influenced by neighbouring AC regions (partitioner workaround). | ||||||||||||||||||
|
|
||||||||||||||||||
| Run as a `joint_custom_pass` so AC tags survive the joint -> fwd/bwd partition. | ||||||||||||||||||
| Run as a ``joint_custom_pass`` so AC tags survive the joint->fwd/bwd | ||||||||||||||||||
| partition. | ||||||||||||||||||
| """ | ||||||||||||||||||
| graph = gm.graph | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -89,32 +143,71 @@ def fsdp_reshard_after_fwd_pass( | |||||||||||||||||
| def autobucketing_reordering_pass(gm: torch.fx.GraphModule, example_inputs: Any | None = None) -> torch.fx.GraphModule: | ||||||||||||||||||
| """Bucket and reorder collectives for compute/comm overlap. | ||||||||||||||||||
|
|
||||||||||||||||||
| Calls `schedule_overlap_bucketing` with `collective_bucketing=True`, which | ||||||||||||||||||
| fuses small all_gathers / reduce_scatters into larger ones and reorders | ||||||||||||||||||
| around compute. Used as fw/bw compiler in the aot_eager path. | ||||||||||||||||||
| Calls ``schedule_overlap_bucketing`` with ``collective_bucketing=True``, | ||||||||||||||||||
| which fuses small all_gathers / reduce_scatters into larger ones and | ||||||||||||||||||
| reorders around compute. Used as fw/bw compiler in the aot_eager path. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| schedule_overlap_bucketing(gm, collective_bucketing=True) | ||||||||||||||||||
| gm.recompile() | ||||||||||||||||||
| return gm | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def transformer_block_bucketing_reordering_pass( | ||||||||||||||||||
| gm: torch.fx.GraphModule, | ||||||||||||||||||
| example_inputs: Any | None = None, | ||||||||||||||||||
| *, | ||||||||||||||||||
| fsdp_manual_buckets: list[list[str] | str], | ||||||||||||||||||
| ) -> torch.fx.GraphModule: | ||||||||||||||||||
| """Manual aten-level bucketing and reordering per transformer block. | ||||||||||||||||||
|
|
||||||||||||||||||
| Uses ``manual_overlap_bucketing`` (pytorch/pytorch#165487) which | ||||||||||||||||||
| groups collectives by module FQN bucket and emits prefetch edges. | ||||||||||||||||||
| Requires the model to have been annotated with ``annotate_module_fqns`` | ||||||||||||||||||
| before tracing so nodes carry ``module_fqn`` metadata. | ||||||||||||||||||
| """ | ||||||||||||||||||
| manual_overlap_bucketing(gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=False) | ||||||||||||||||||
| gm.recompile() | ||||||||||||||||||
| return gm | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_simple_fsdp_compile_backend( | ||||||||||||||||||
| *, fsdp_reshard_after_forward: bool, auto_bucketing: bool, backend: str | ||||||||||||||||||
| *, | ||||||||||||||||||
| fsdp_reshard_after_forward: bool, | ||||||||||||||||||
| bucketing_mode: str = "none", | ||||||||||||||||||
| fsdp_manual_buckets: list[list[str] | str] | None = None, | ||||||||||||||||||
| async_tensor_parallel: bool = False, | ||||||||||||||||||
| backend: str, | ||||||||||||||||||
| ) -> Callable: | ||||||||||||||||||
| """Build a torch.compile backend that wires SimpleFSDP-aware passes. | ||||||||||||||||||
|
|
||||||||||||||||||
| - `fsdp_reshard_after_forward`: install `annotate_fsdp_all_gather` as the | ||||||||||||||||||
| joint_custom_pass so AG nodes get AC-recompute tags before the fwd/bwd | ||||||||||||||||||
| partition runs. | ||||||||||||||||||
| - `auto_bucketing`: enable inductor's distributed-collective bucketing | ||||||||||||||||||
| scheduler. Conflicts with `reorder_for_compute_comm_overlap`, so flips | ||||||||||||||||||
| that off; uses `post_grad_custom_post_pass` (inductor) or `aot_autograd` | ||||||||||||||||||
| with the bucketing pass as fw/bw compiler (aot_eager). | ||||||||||||||||||
| Args: | ||||||||||||||||||
| fsdp_reshard_after_forward: Install ``annotate_fsdp_all_gather`` as the | ||||||||||||||||||
| joint_custom_pass so AG nodes get AC-recompute tags before the | ||||||||||||||||||
| fwd/bwd partition runs. | ||||||||||||||||||
| bucketing_mode: Collective bucketing strategy. | ||||||||||||||||||
| - ``"none"``: no bucketing (default). | ||||||||||||||||||
| - ``"auto"``: global auto-bucketing via ``schedule_overlap_bucketing``. | ||||||||||||||||||
| - ``"transformer_block"``: per-layer manual bucketing via | ||||||||||||||||||
| ``manual_overlap_bucketing``. Requires ``fsdp_manual_buckets`` | ||||||||||||||||||
| and that the model was annotated with ``annotate_module_fqns`` | ||||||||||||||||||
| before compilation. | ||||||||||||||||||
| fsdp_manual_buckets: Module FQN bucket plans for ``"transformer_block"`` | ||||||||||||||||||
| mode. Each entry is a module FQN string or a list of FQN strings | ||||||||||||||||||
| that should be grouped into one bucket. | ||||||||||||||||||
| async_tensor_parallel: Apply ``async_tensor_parallel_pass`` to fuse TP | ||||||||||||||||||
| collectives with matmuls via symmetric memory. | ||||||||||||||||||
| backend: Underlying torch.compile backend (``"inductor"``, | ||||||||||||||||||
| ``"aot_eager"``, ``"neuron"``, …). | ||||||||||||||||||
| """ | ||||||||||||||||||
| if bucketing_mode not in ("none", "auto", "transformer_block"): | ||||||||||||||||||
| raise ValueError(f"bucketing_mode must be 'none', 'auto', or 'transformer_block', got {bucketing_mode!r}") | ||||||||||||||||||
| if bucketing_mode == "transformer_block" and not fsdp_manual_buckets: | ||||||||||||||||||
| raise ValueError("fsdp_manual_buckets must be provided when bucketing_mode='transformer_block'") | ||||||||||||||||||
|
|
||||||||||||||||||
| torch._dynamo.config.capture_scalar_outputs = True | ||||||||||||||||||
| inner_backend = torch._dynamo.lookup_backend(backend) | ||||||||||||||||||
|
|
||||||||||||||||||
| if auto_bucketing: | ||||||||||||||||||
| if bucketing_mode == "auto": | ||||||||||||||||||
| if backend == "aot_eager": | ||||||||||||||||||
| aten_distributed_optimizations.insert_overlap_deps = False | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -123,23 +216,45 @@ def get_simple_fsdp_compile_backend( | |||||||||||||||||
| bw_compiler=autobucketing_reordering_pass, | ||||||||||||||||||
| keep_inference_input_mutations=True, | ||||||||||||||||||
| ) | ||||||||||||||||||
| elif backend in ["inductor", "neuron"]: | ||||||||||||||||||
| elif backend in ("inductor", "neuron"): | ||||||||||||||||||
|
|
||||||||||||||||||
| def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule: | ||||||||||||||||||
| return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module) | ||||||||||||||||||
| return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True) | ||||||||||||||||||
|
Comment on lines
221
to
+222
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In Inductor,
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| aten_distributed_optimizations.collective_bucketing = True | ||||||||||||||||||
| aten_distributed_optimizations.insert_overlap_deps = True | ||||||||||||||||||
|
|
||||||||||||||||||
| torch._inductor.config.allow_buffer_reuse = False | ||||||||||||||||||
| torch._inductor.config.reorder_for_peak_memory = False | ||||||||||||||||||
| torch._inductor.config.reorder_for_compute_comm_overlap = False | ||||||||||||||||||
| torch._inductor.config.post_grad_custom_post_pass = _inductor_autobucketing_pass | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError(f"Unsupported backend {backend} for auto_bucketing") | ||||||||||||||||||
| raise ValueError(f"Unsupported backend {backend!r} for bucketing_mode='auto'") | ||||||||||||||||||
| elif bucketing_mode == "transformer_block": | ||||||||||||||||||
| if backend == "aot_eager": | ||||||||||||||||||
| _tb_pass = partial(transformer_block_bucketing_reordering_pass, fsdp_manual_buckets=fsdp_manual_buckets) | ||||||||||||||||||
|
|
||||||||||||||||||
| inner_backend = aot_autograd_backend( | ||||||||||||||||||
| fw_compiler=_tb_pass, bw_compiler=_tb_pass, keep_inference_input_mutations=True | ||||||||||||||||||
| ) | ||||||||||||||||||
| elif backend in ("inductor", "neuron"): | ||||||||||||||||||
|
|
||||||||||||||||||
| def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule: | ||||||||||||||||||
| return manual_overlap_bucketing( | ||||||||||||||||||
| gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True | ||||||||||||||||||
| ) | ||||||||||||||||||
|
Comment on lines
+241
to
+244
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the autobucketing pass,
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| torch._inductor.config.allow_buffer_reuse = False | ||||||||||||||||||
| torch._inductor.config.reorder_for_peak_memory = False | ||||||||||||||||||
| torch._inductor.config.reorder_for_compute_comm_overlap = False | ||||||||||||||||||
| torch._inductor.config.post_grad_custom_post_pass = _inductor_tb_pass | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError(f"Unsupported backend {backend!r} for bucketing_mode='transformer_block'") | ||||||||||||||||||
|
|
||||||||||||||||||
| def _joint_ac_pass(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule: | ||||||||||||||||||
| return fsdp_reshard_after_fwd_pass(gm, example_inputs, reshard_after_forward=fsdp_reshard_after_forward) | ||||||||||||||||||
| gm = fsdp_reshard_after_fwd_pass(gm, example_inputs, reshard_after_forward=fsdp_reshard_after_forward) | ||||||||||||||||||
| if async_tensor_parallel: | ||||||||||||||||||
| gm = async_tensor_parallel_pass(gm, example_inputs) | ||||||||||||||||||
| return gm | ||||||||||||||||||
|
|
||||||||||||||||||
| def _backend_with_passes(*args, **kwargs): | ||||||||||||||||||
| with functorch_config.patch("joint_custom_pass", _joint_ac_pass): | ||||||||||||||||||
|
|
||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.