gfx950 MoE A8W4: tuned entries for gpt-oss shapes + fallback hardening#3580

Open
xiaohuguo2023 wants to merge 5 commits into main from xguo/gptoss-a8w4-dispatch-additions

Conversation

@xiaohuguo2023
Member

  • 15 tuned JSON entries added to gfx950-A8W4.json for the
    (block_m, N, K) shapes hit by gpt-oss-120b W4A8 at TP={1,2,4}.
    All use the same tuning:
    BLOCK_N=128, BLOCK_K=256, num_stages=2, num_warps=4

  • Proxy fallback skips BLOCK_K < 256 entries. The JSON has
    a handful of small-BLOCK_K entries for Qwen3/Llama4 shapes
    that are fine for their own exact lookup, but if borrowed by a
    caller passing swizzle_mx_scale="CDNA4_SCALE" the CDNA4
    unswizzle kernel can't compile them and crashes.

  • gfx950 heuristic picks num_stages via pick_gemm_num_stages
    instead of hardcoding ns=1. Same helper moe_op_gemm_a8w8.py
    uses — returns ns=2 when the tile fits in LDS, else ns=1.
    block_m == 64 keeps its rocprof-tuned ns=1.
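
The proxy-fallback guard in the second bullet can be sketched roughly as below. This is an illustration only, not the code in moe_op_gemm_a8w4.py: the function names, the nearest-shape distance metric, and the sample table are assumptions. Only the BLOCK_SIZE_K >= 256 filter comes from the description above, and the `bm…_n…_k…` key shape is borrowed from the entry name `bm16_n2880_k360` quoted later in this conversation.

```python
import re

# From the PR text: K tiles smaller than this fail under CDNA4 unswizzle.
MIN_SWIZZLE_BLOCK_K = 256


def parse_key(key):
    """Split a 'bm32_n3072_k768'-style key into (block_m, n, k)."""
    m = re.fullmatch(r"bm(\d+)_n(\d+)_k(\d+)", key)
    if m is None:
        raise ValueError(f"unexpected key: {key}")
    return tuple(int(g) for g in m.groups())


def pick_proxy_config(table, block_m, n, k):
    """Hypothetical proxy lookup: borrow the nearest tuned entry with the
    same block_m, skipping entries whose BLOCK_SIZE_K is below 256."""
    best = None
    for key, cfg in table.items():
        bm, tuned_n, tuned_k = parse_key(key)
        if bm != block_m:
            continue
        if cfg["BLOCK_SIZE_K"] < MIN_SWIZZLE_BLOCK_K:
            continue  # fine for its own exact lookup, unsafe to borrow
        dist = abs(tuned_n - n) + abs(tuned_k - k)  # assumed distance metric
        if best is None or dist < best[0]:
            best = (dist, cfg)
    return None if best is None else best[1]


# Made-up table: the second entry is closer to the query but has a small K tile.
table = {
    "bm32_n3072_k768": {"BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 128},
    "bm32_n3000_k700": {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 64},
}
chosen = pick_proxy_config(table, block_m=32, n=3008, k=704)
print(chosen)
```

In this toy table the small-BLOCK_K entry is the nearest neighbour, so without the filter it would be the one borrowed.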

…stunes

get_kernel_config_triton() falls back to an arch-heuristic when a (bm, N, K)
lookup misses gfx950-A8W4.json. For 15 (bm, N, K) tuples exercised by
gpt-oss-120b W4A8 at TP={1,2,4}, the heuristic picks BLOCK_SIZE_N in {256,512}
with num_stages=1, while aiter 0.1.13 (via gfx950-MOE-MX_FP4_A8.json) and
direct micro-tuning both prefer BLOCK_SIZE_N=128, num_stages=2 for the same
shapes.

Adds entries for bm{32,64,128} x N{1536,3072,6144} x K{768,1536,3072} (skipping
those already present), all tuned to:
  BLOCK_SIZE_K=256, BLOCK_SIZE_N=128, matrix_instr_nonkdim=16,
  num_stages=2, num_warps=4, waves_per_eu=0
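
As a rough sketch, the shape grid above could be expanded into entries like this. The key format (`bm…_n…_k…`, modelled on `bm16_n2880_k360` from later in this thread), the flat JSON layout, and the `already_present` placeholder are assumptions; the tuning values are the ones listed above.

```python
import itertools
import json

# Tuning values quoted in the commit message.
TUNING = {
    "BLOCK_SIZE_K": 256,
    "BLOCK_SIZE_N": 128,
    "matrix_instr_nonkdim": 16,
    "num_stages": 2,
    "num_warps": 4,
    "waves_per_eu": 0,
}

BLOCK_MS = (32, 64, 128)
NS = (1536, 3072, 6144)
KS = (768, 1536, 3072)


def missing_entries(existing_keys):
    """Build an entry for every grid point not already in the table."""
    out = {}
    for bm, n, k in itertools.product(BLOCK_MS, NS, KS):
        key = f"bm{bm}_n{n}_k{k}"  # assumed key format
        if key not in existing_keys:
            out[key] = dict(TUNING)
    return out


# Placeholder only: the real set of pre-existing keys lives in gfx950-A8W4.json.
already_present = {"bm32_n1536_k768"}
new_entries = missing_entries(already_present)
print(len(new_entries))
print(json.dumps(new_entries["bm64_n3072_k1536"], indent=2))
```

The full grid has 27 points, so the number of entries actually added depends on how many were already in the file.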

Config probe on gpt-oss-120b W4A8 (TP=1/CONC=32/ISL=1024/OSL=1024) shows
106 of 210 shared (M, N, K, bm) shapes pick a different config between
aiter 0.1.13 and aiter HEAD; all 106 divergences trace back to these 15
missing entries.
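
A config probe of this kind amounts to comparing, shape by shape, the config each version selects. The sketch below shows one way to do that; the function name and the sample data are made up for illustration and are not the probe used for the numbers above.

```python
def diverging_shapes(old_cfgs, new_cfgs):
    """Return the shapes present in both mappings whose configs differ."""
    shared = old_cfgs.keys() & new_cfgs.keys()
    return sorted(shape for shape in shared if old_cfgs[shape] != new_cfgs[shape])


# Keys are (M, N, K, bm); all values here are invented sample data.
old = {
    (32, 3072, 768, 32): {"BLOCK_SIZE_N": 128, "num_stages": 2},
    (64, 1536, 768, 64): {"BLOCK_SIZE_N": 128, "num_stages": 2},
    (16, 720, 2880, 16): {"BLOCK_SIZE_N": 64, "num_stages": 2},
}
new = {
    (32, 3072, 768, 32): {"BLOCK_SIZE_N": 256, "num_stages": 1},
    (64, 1536, 768, 64): {"BLOCK_SIZE_N": 128, "num_stages": 2},
    (128, 6144, 3072, 128): {"BLOCK_SIZE_N": 512, "num_stages": 1},
}

diffs = diverging_shapes(old, new)
print(f"{len(diffs)} of {len(old.keys() & new.keys())} shared shapes diverge")
```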
  - Proxy fallback now skips BLOCK_K<256 (CDNA4 unswizzle won't compile).
  - gfx950 heuristic uses pick_gemm_num_stages instead of hardcoded ns=1.
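
The idea behind picking num_stages from an LDS-fit check can be sketched as follows. This is not the signature or logic of the real pick_gemm_num_stages helper, which is not shown in this conversation: the function name, the per-stage byte estimate, and the default LDS capacity are all assumptions to be checked against the actual code and hardware documentation.

```python
def pick_num_stages_sketch(block_m, block_n, block_k, lds_bytes=160 * 1024):
    """Hypothetical stand-in for an LDS-fit num_stages heuristic.

    Assumes the A tile holds 8-bit activations (1 byte per element) and the
    B tile holds 4-bit weights (2 elements per byte), and that two stages
    need two copies of both tiles resident in LDS.
    """
    if block_m == 64:
        return 1  # per the PR text, block_m == 64 keeps its rocprof-tuned ns=1
    a_tile_bytes = block_m * block_k
    b_tile_bytes = block_k * block_n // 2
    per_stage = a_tile_bytes + b_tile_bytes
    return 2 if 2 * per_stage <= lds_bytes else 1


print(pick_num_stages_sketch(32, 128, 256))
print(pick_num_stages_sketch(64, 128, 256))
print(pick_num_stages_sketch(128, 512, 512, lds_bytes=64 * 1024))
```

The byte estimate ignores scale tensors and any padding, so it only illustrates the "fits in LDS, else fall back to one stage" shape of the decision.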
@xiaohuguo2023 xiaohuguo2023 requested review from a team and Copilot June 6, 2026 23:21
@xiaohuguo2023 xiaohuguo2023 requested review from brunomazzottiamd and lburzawa and removed request for Copilot June 6, 2026 23:21
@github-actions
Contributor

github-actions Bot commented Jun 6, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3580 --add-label <label>

@Rohan138
Contributor

Rohan138 commented Jun 8, 2026

Validated against gpt-oss-120b W4A8 (amd/gpt-oss120b-w-mxfp4-a-fp8) at TP=8 on MI355X with vllm@nightly. The proxy guard + new tuned entries work for shapes where lookup misses, but the kernel still crashes for one shape vLLM dispatches that has a problematic direct tuned entry.

The dispatched gemm shapes for TP=8 are:

block_m=16  n=2880 k=360    <- exact-lookup hits bm16_n2880_k360, BLOCK_K=128 -> crash
block_m=16  n=720  k=2880   <- safe (BLOCK_K=512)
block_m=32  n=3072 k=512    <- no entry, no proxy, falls to heuristic, safe (BLOCK_K=256)
block_m=32  n=1024 k=3072   <- no entry, no proxy, falls to heuristic, safe (BLOCK_K=256)
block_m=128 n=3072 k=512    <- no entry, no proxy, falls to heuristic, safe
block_m=128 n=1024 k=3072   <- no entry, no proxy, falls to heuristic, safe

bm16_n2880_k360 (BLOCK_K=128) crashes unswizzle_mx_scale_cdna4 exactly as before. The proxy guard in this PR doesn't help because the exact-lookup hits first.
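
The lookup order implied by the trace above (exact entry, then proxy, then heuristic) explains why a guard that lives only in the proxy step cannot catch this case. The sketch below uses hypothetical names and stand-in proxy and heuristic functions; it is not the code of get_kernel_config_triton().

```python
def resolve_config(table, block_m, n, k, proxy_fn, heuristic_fn):
    """Return (source, config) following the exact -> proxy -> heuristic order."""
    key = f"bm{block_m}_n{n}_k{k}"
    if key in table:
        # Returned as-is: no proxy-side filtering is applied to exact hits.
        return "exact", table[key]
    proxy = proxy_fn(table, block_m, n, k)
    if proxy is not None:
        return "proxy", proxy
    return "heuristic", heuristic_fn(block_m, n, k)


proxy_calls = []


def guarded_proxy(table, block_m, n, k):
    """Stand-in for the guarded proxy step; records that it was reached."""
    proxy_calls.append((block_m, n, k))
    return None  # mirrors the "no proxy" rows in the trace


def heuristic(block_m, n, k):
    """Stand-in heuristic that returns a safe K tile."""
    return {"BLOCK_SIZE_K": 256}


table = {
    "bm16_n2880_k360": {"BLOCK_SIZE_K": 128},
    "bm16_n720_k2880": {"BLOCK_SIZE_K": 512},
}

crash_case = resolve_config(table, 16, 2880, 360, guarded_proxy, heuristic)
miss_case = resolve_config(table, 32, 3072, 512, guarded_proxy, heuristic)
print(crash_case, miss_case, proxy_calls)
```

For the `bm16_n2880_k360` query the proxy function is never called, so the BLOCK_K=128 entry reaches the kernel unchanged.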

The n=2880, k=360 shape is the unpadded gpt-oss-120b TP=8 down-projection. vLLM's aiter_mxfp4_w4a8_moe.py passes unpadded_N / unpadded_K to moe_gemm_a8w4, which at block_m == 16 overrides the (padded) N, K from the weight tensor with the unpadded values before the kernel-config lookup. That's intentional for output sizing, but it means tuned entries are keyed on unpadded sizes.

Two options to close this:

  1. Clamp on the direct lookup too, not just the proxy: block_k = max(tuned["BLOCK_SIZE_K"], 256) when the caller will use CDNA4 swizzle. Same shape continues to use the tuner-picked block_n/num_warps, just clamps the K tile.
  2. Add a bm=16 tuned entry for n=2880 k=360 with BLOCK_K ≥ 256 (and similar for other TP=8 down-projection shapes). Will also need n=2880 k=720, n=2880 k=1440 for TP=4/TP=2.

Option 1 is one-line; option 2 needs a tuning sweep.
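
A minimal sketch of option 1, assuming the tuned config is a plain dict and that the swizzle mode arrives as the string "CDNA4_SCALE" (the value quoted in the PR description). The function name and where it would be called from are hypothetical.

```python
CDNA4_MIN_BLOCK_K = 256  # smallest K tile reported safe for CDNA4 unswizzle


def clamp_block_k(tuned, swizzle_mx_scale):
    """Return a copy of the tuned config with BLOCK_SIZE_K raised to at
    least 256 when CDNA4 swizzle is requested; other fields are untouched."""
    cfg = dict(tuned)
    if swizzle_mx_scale == "CDNA4_SCALE":
        cfg["BLOCK_SIZE_K"] = max(cfg["BLOCK_SIZE_K"], CDNA4_MIN_BLOCK_K)
    return cfg


# Made-up tuned entry with a small K tile.
tuned = {"BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 64, "num_warps": 4}
print(clamp_block_k(tuned, "CDNA4_SCALE"))
print(clamp_block_k(tuned, None))
```

Copying the dict keeps the cached tuned entry intact for callers that do not use CDNA4 swizzle.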

Note: vllm-project/vllm#44804 (Xiaohu's vLLM-side fix) sidesteps this by gating CDNA4 swizzle off at TP≥4, so the kernel doesn't take the unswizzle path at all for the affected shapes. With #44804 applied I verified that W4A8 TP=8 runs cleanly (326 tok/s at mc=1 and 2087 tok/s at mc=8 on amd/gpt-oss120b-w-mxfp4-a-fp8). That makes the proxy guard in this PR a defense-in-depth measure rather than the critical fix for TP=8.

Copilot AI review requested due to automatic review settings June 8, 2026 12:06
Contributor

Copilot AI left a comment

Pull request overview

Adds/updates gfx950 MoE A8W4 tuning and fallback logic to improve performance on gpt-oss shapes while hardening CDNA4-scale swizzle behavior to avoid compilation crashes.

Changes:

  • Add 15 tuned (block_m, N, K) entries to gfx950-A8W4.json targeting gpt-oss-120b W4A8 shapes.
  • Harden untuned-shape proxy fallback to avoid borrowing tuned entries with small BLOCK_K that can’t compile under CDNA4 unswizzle.
  • Improve gfx950 heuristic selection by capping BLOCK_N and selecting num_stages via pick_gemm_num_stages, plus clamp block_k when CDNA4 swizzle is requested.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| aiter/ops/triton/moe/moe_op_gemm_a8w4.py | Updates config selection heuristics (proxy filtering, gfx950 BLOCK_N cap, num_stages picking) and adds a CDNA4 swizzle block_k clamp. |
| aiter/ops/triton/configs/moe/gfx950-A8W4.json | Adds tuned entries for gpt-oss shapes; also removes one existing bm16 tuned entry (see comment). |

Comment thread aiter/ops/triton/moe/moe_op_gemm_a8w4.py Outdated
Comment thread aiter/ops/triton/configs/moe/gfx950-A8W4.json
@xiaohuguo2023
Member Author

xiaohuguo2023 commented Jun 8, 2026

Thanks @Rohan138!

For TP4/TP8, I have implemented your suggestion: BLOCK_K is clamped to >= 256 when the caller requests CDNA4 swizzle.

All of this may need to change once vLLM upgrades from Triton 3.6 to Triton 3.7.

The vLLM-side change, however, is purely for performance reasons: turning swizzle off makes a significant difference when TP > 4.

…ssion and use_async_padding=True for async-copy lowering on gfx950