gfx950 MoE A8W4: tuned entries for gpt-oss shapes + fallback hardening#3580
gfx950 MoE A8W4: tuned entries for gpt-oss shapes + fallback hardening#3580xiaohuguo2023 wants to merge 5 commits into
Conversation
…stunes
get_kernel_config_triton() falls back to an arch-heuristic when a (bm, N, K)
lookup misses gfx950-A8W4.json. For 15 (bm, N, K) tuples exercised by
gpt-oss-120b W4A8 at TP={1,2,4}, the heuristic picks BLOCK_SIZE_N in {256,512}
with num_stages=1, while aiter 0.1.13 (via gfx950-MOE-MX_FP4_A8.json) and
direct micro-tuning both prefer BLOCK_SIZE_N=128, num_stages=2 for the same
shapes.
Adds entries for bm{32,64,128} x N{1536,3072,6144} x K{768,1536,3072} (skipping
those already present), all tuned to:
BLOCK_SIZE_K=256, BLOCK_SIZE_N=128, matrix_instr_nonkdim=16,
num_stages=2, num_warps=4, waves_per_eu=0
Config probe on gpt-oss-120b W4A8 (TP=1/CONC=32/ISL=1024/OSL=1024) shows
106 of 210 shared (M, N, K, bm) shapes pick a different config between
aiter 0.1.13 and aiter HEAD; all 106 divergences trace back to these 15
missing entries.
- Proxy fallback now skips BLOCK_K<256 (CDNA4 unswizzle won't compile). - gfx950 heuristic uses pick_gemm_num_stages instead of hardcoded ns=1.
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
|
Validated against gpt-oss-120b W4A8 ( The dispatched gemm shapes for TP=8 are:
The Two options to close this:
Option 1 is one-line; option 2 needs a tuning sweep. Note: vllm-project/vllm#44804 (Xiaohu's vLLM-side fix) sidesteps this by gating CDNA4 swizzle off at TP≥4, so the kernel doesn't take the unswizzle path at all for the affected shapes. With #44804 applied I verified W4A8 TP=8 runs cleanly (mc=1=326 tok/s, mc=8=2087 tok/s on |
There was a problem hiding this comment.
Pull request overview
Adds/updates gfx950 MoE A8W4 tuning and fallback logic to improve performance on gpt-oss shapes while hardening CDNA4-scale swizzle behavior to avoid compilation crashes.
Changes:
- Add 15 tuned
(block_m, N, K)entries togfx950-A8W4.jsontargeting gpt-oss-120b W4A8 shapes. - Harden untuned-shape proxy fallback to avoid borrowing tuned entries with small
BLOCK_Kthat can’t compile under CDNA4 unswizzle. - Improve gfx950 heuristic selection by capping
BLOCK_Nand selectingnum_stagesviapick_gemm_num_stages, plus clampblock_kwhen CDNA4 swizzle is requested.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| aiter/ops/triton/moe/moe_op_gemm_a8w4.py | Updates config selection heuristics (proxy filtering, gfx950 BLOCK_N cap, num_stages picking) and adds a CDNA4 swizzle block_k clamp. |
| aiter/ops/triton/configs/moe/gfx950-A8W4.json | Adds tuned entries for gpt-oss shapes; also removes one existing bm16 tuned entry (see comment). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Thanks @Rohan138! TP4/TP8, I have implemented as you suggested, clamp BLOCK_K>=256 when caller requests CDNA4 All these may need change once vllm upgrade from triton 3.6 to triton 3.7. But vllm side change is purely for perf reason. significant difference to use swizzle off when TP > 4. |
…ssion and use_async_padding=True for async-copy lowering on gfx950
15 tuned JSON entries added to
gfx950-A8W4.jsonfor the(block_m, N, K)shapes hit by gpt-oss-120b W4A8 at TP={1,2,4}.All use the same tuning:
BLOCK_N=128, BLOCK_K=256, num_stages=2, num_warps=4Proxy fallback skips
BLOCK_K < 256entries. The JSON hasa handful of small-
BLOCK_Kentries for Qwen3/Llama4 shapesthat are fine for their own exact lookup, but if borrowed by a
caller passing
swizzle_mx_scale="CDNA4_SCALE"the CDNA4unswizzle kernel can't compile them and crashes.
gfx950 heuristic picks
num_stagesviapick_gemm_num_stagesinstead of hardcoding
ns=1. Same helpermoe_op_gemm_a8w8.pyuses — returns
ns=2when the tile fits in LDS, elsens=1.block_m == 64keeps its rocprof-tunedns=1.