refactor: unify cudagraph-vs-eager dispatch via ForwardMode #1120

Open
JiaoliangYu wants to merge 1 commit into ROCm:main from JiaoliangYu:fix/dp-attn-eager-slot-mapping-mismatch

@JiaoliangYu
Contributor

prepare_inputs and run_model each derived the eager-vs-cudagraph decision from their own copy of the same four-way OR chain, and the two copies drifted under PR #930's dp_uniform_decode path: prepare_inputs rounded bs up to the next captured graph size, while run_model took the eager branch with the local, unpadded input_ids. This left slot_mapping at the rounded length, which downstream tripped aiter's fused_qk_rope_reshape_and_cache assert (t_slot=48 > t=33).
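The drift described above can be reproduced with plain integers. This is a minimal sketch, not code from this PR: the capture-size list and the helper name `round_up_to_graph_size` are assumptions chosen so that a real batch of 33 rounds up to 48, matching the numbers in the assert message.

```python
from bisect import bisect_left

# Hypothetical capture sizes; the real list depends on the runner's config.
CAPTURED_GRAPH_SIZES = [1, 2, 4, 8, 16, 32, 48, 64]


def round_up_to_graph_size(bs, sizes=CAPTURED_GRAPH_SIZES):
    """Return the smallest captured size >= bs, or bs itself if none fits."""
    i = bisect_left(sizes, bs)
    return sizes[i] if i < len(sizes) else bs


real_bs = 33
padded_bs = round_up_to_graph_size(real_bs)

# Pre-fix drift: the metadata is sized as if the graph path will run,
# while the tokens are sized for the eager path.
slot_mapping_len = padded_bs   # what prepare_inputs produced
input_ids_len = real_bs        # what run_model actually fed in eager mode
mismatch = slot_mapping_len > input_ids_len
```

With these assumed sizes, `padded_bs` is 48 and `mismatch` is true, which is the shape of the `t_slot=48 > t=33` failure.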

Introduce ForwardMode (frozen dataclass) owning the decision and its data:

  • ForwardMode.decide() is the single rule (all four force-eager conditions live here).
  • prepare_inputs uses effective_bs to size attn_metadata.
  • run_model strict-asserts forward_mode is set, then dispatches off use_cudagraph.
  • assert_shape_contract() guards the slot_mapping <-> input_ids invariant (skips prefill / cudagraph internally so callers stay decision-free).
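The design in the list above could look roughly like the following. This is a hedged sketch, not the PR's implementation: only the names `ForwardMode`, `decide`, `effective_bs`, `use_cudagraph`, and `assert_shape_contract` come from the description. The field set, the argument list, and the exact four force-eager conditions (here prefill, `enforce_eager`, non-uniform DP decode, and a batch larger than any captured graph) are assumptions for illustration.

```python
from bisect import bisect_left
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ForwardMode:
    use_cudagraph: bool
    is_prefill: bool
    real_bs: int
    effective_bs: int  # batch size that attention metadata is sized to

    @classmethod
    def decide(cls, *, real_bs: int, graph_sizes: Sequence[int],
               is_prefill: bool, enforce_eager: bool,
               dp_uniform_decode: bool) -> "ForwardMode":
        """Single place where eager-vs-cudagraph is decided (assumed rule)."""
        sizes = sorted(graph_sizes)
        too_large = not sizes or real_bs > sizes[-1]
        force_eager = (is_prefill or enforce_eager
                       or not dp_uniform_decode or too_large)
        if force_eager:
            # Eager path: no padding, metadata matches the real batch.
            return cls(False, is_prefill, real_bs, real_bs)
        # Graph path: pad up to the next captured size.
        padded = sizes[bisect_left(sizes, real_bs)]
        return cls(True, is_prefill, real_bs, padded)

    def assert_shape_contract(self, slot_mapping_len: int,
                              input_ids_len: int) -> None:
        """Check slot_mapping against input_ids on the eager decode path."""
        if self.is_prefill or self.use_cudagraph:
            return  # skipped internally so callers need no branching
        assert slot_mapping_len == input_ids_len, (
            f"slot_mapping ({slot_mapping_len}) != input_ids ({input_ids_len})"
        )


sizes = [1, 2, 4, 8, 16, 32, 48, 64]  # hypothetical capture sizes
graph = ForwardMode.decide(real_bs=33, graph_sizes=sizes, is_prefill=False,
                           enforce_eager=False, dp_uniform_decode=True)
eager = ForwardMode.decide(real_bs=33, graph_sizes=sizes, is_prefill=False,
                           enforce_eager=False, dp_uniform_decode=False)
eager.assert_shape_contract(eager.effective_bs, 33)  # passes: 33 == 33
```

Because both callers read the same frozen object, the padded size and the dispatch branch cannot disagree: in this sketch the graph case yields `effective_bs` of 48 and the eager case yields 33.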

Motivation

Consolidate eager-vs-cudagraph dispatch and padding decisions into a single ForwardMode.

Technical Details

Test Plan

Same as CI, run on gpt-oss-120b.

Test Result

flex-extract: 0.8832, strict-extract: 0.3806

Submission Checklist

@JiaoliangYu JiaoliangYu force-pushed the fix/dp-attn-eager-slot-mapping-mismatch branch from 83d3b3b to fc36062 Compare June 7, 2026 06:27
Comment thread atom/utils/forward_context.py Outdated
is_prefill=True,
)

if enforce_eager or not dp_uniform_decode:
Contributor

In decode, graph_bs must be the same on every rank because of the MoE all_gather (https://github.com/ROCm/ATOM/blob/main/atom/model_ops/moe.py#L185). dp_uniform_decode=False only means that decode falls back to eager; graph_bs still needs to be the maximum token count across the decode ranks.

Contributor Author

When dp_uniform_decode=False, we use all_gatherv to handle variable-length inputs.
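The two approaches discussed in this thread can be contrasted with a pure-Python simulation. This is only an illustration under stated assumptions: the function names are hypothetical, no real collective is called, and it does not reflect how ATOM's MoE layer is implemented.

```python
def gather_padded(per_rank_tokens, pad=0):
    """Fixed-size gather: every rank pads to the max length across ranks."""
    max_len = max(len(t) for t in per_rank_tokens)
    return [t + [pad] * (max_len - len(t)) for t in per_rank_tokens]


def gather_variable(per_rank_tokens):
    """Variable-length gather: concatenate as-is and keep per-rank sizes."""
    sizes = [len(t) for t in per_rank_tokens]
    flat = [x for t in per_rank_tokens for x in t]
    return flat, sizes


# Three simulated ranks with different decode token counts.
per_rank = [[1, 2, 3], [4], [5, 6]]
padded = gather_padded(per_rank)
flat, sizes = gather_variable(per_rank)
```

In the padded form every rank contributes the same number of rows, so a uniform batch size is required. In the variable-length form the per-rank sizes travel with the data, so each rank can keep its real token count.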

@JiaoliangYu JiaoliangYu force-pushed the fix/dp-attn-eager-slot-mapping-mismatch branch from fc36062 to 6e2e1c7 Compare June 8, 2026 06:47