[ROCm][Kernel] W4A16 prefill: optimize dequant #985
Open
mgehre-amd wants to merge 1 commit into
mgehre-amd
commented
Jun 8, 2026
… carrier

Speeds up the Triton W4A16 skinny prefill GEMM on gfx11x (RDNA3/3.5). The prefill dequant inner loop was VALU-issue-bound; this cuts the dequant instruction count and, for asymmetric layers, folds the per-group scale + zero-point into a single load. ~14-15% lower TTFT on a W4A16 model with no accuracy change. Pure Python/Triton; no C++/HIP rebuild.

Changes:
- Packed int4->fp16 dequant (_i4_and_or_magic): one v_and_or_b32 dequants two nibbles into fp16 (1024+n) per instruction (the i4_to_half magic trick), replacing the scalar v_and_b16/v_or_b16 pair. The kernel selects the packed path itself at JIT time -- fp16 AND RDNA gfx11/gfx12, via a tl.target_info.constexpr_function (_target_is_gfx1x); no host flag. Everything else uses the scalar unpack.
- Distilled gfx11 tile table (_select_skinny_gfx11_config): BLOCK_N=256, num_warps=8, BLOCK_K=32, BLOCK_M 64/128 by K and N. Other arches unchanged.
- Asymmetric layers: packed_scale_zp carrier (one fp32 per (n, group)) folds the per-group scale and zero-point offset into a single load. fp16 packs scale|bias_eff and consumes it with one v_pk_fma_f16; bf16 packs scale|zp_int and keeps the int-domain subtract (RDNA3 has no v_pk_fma_bf16). Materialised only for asym layers; the kernel's HAS_ZP constexpr selects carrier vs scales.
- Symmetric layers keep a dedicated fast path (scales + constant -8 offset): sym has no second load to fold, so the carrier would be pure overhead (measured ~+8% on fp16 sym, up to +22% on o_proj).
- perf test: exercises the carrier for asym providers, adds Qwen3-1.7B and Gemma3-4B prefill shapes; gfx1151 golden regenerated.

Measured on gfx1151 (Qwen3-4B-AWQ, fp16, input 3968 / output 1): TTFT 1436 -> 1234 ms (-14.1%). bf16 asym carrier -4.8% (do_bench). Carrier dequant verified numerically (fp16 <1e-3 rel error; bf16 bit-identical).

Verified: tests/kernels/quantization/test_hybrid_w4a16_perf.py 111/112 pass -- the one failure is a pre-existing flaky tiny-M wvSplitK_int4 decode cell (unchanged HIP kernel), unrelated to this change. ruff + mypy clean.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Summary
Speeds up the Triton W4A16 skinny prefill GEMM on gfx11x (RDNA3/3.5). The
prefill dequant inner loop is VALU-issue-bound; this change cuts the dequant
instruction count and, for asymmetric layers, folds the per-group scale and
zero-point into a single load. ~14–15% lower prefill TTFT on a W4A16 model with
no accuracy change. Pure Python/Triton — no C++/HIP rebuild.
What it does
vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py:
- Packed int4→fp16 dequant (_i4_and_or_magic): one v_and_or_b32 dequants two int4 nibbles into fp16 1024+n per instruction (the i4_to_half magic-constant trick), replacing the scalar v_and_b16/v_or_b16 pair. The kernel selects this path at JIT time (fp16 and RDNA gfx11/gfx12) via a tl.target_info.constexpr_function (_target_is_gfx1x); no host flag is passed. All other cases use the scalar unpack.
- Distilled gfx11 tile table (_select_skinny_gfx11_config): BLOCK_N=256, num_warps=8, BLOCK_K=32, with BLOCK_M 64/128 chosen by K and N. Other architectures keep their existing config.
- Asymmetric layers, packed_scale_zp carrier: one fp32 per (n, group) packs the per-group scale and zero-point offset, so the prefill kernel does a single load instead of separate scale + zero-point loads. fp16 packs scale | bias_eff and consumes it with one v_pk_fma_f16; bf16 packs scale | zp_int and keeps an int-domain subtract (RDNA3 has no v_pk_fma_bf16). The carrier is materialised at load time only for asymmetric layers; the kernel's HAS_ZP constexpr selects carrier vs scales.
- Symmetric layers, dedicated fast path: the −8 offset is a constant, so there is no second load to fold and a carrier would be pure overhead. Sym layers read the scale directly and subtract the constant.
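The magic-constant trick behind the packed dequant can be illustrated on the host. The following is a minimal sketch in plain Python, not the kernel code: the helper names are hypothetical, and the nibble layout (eight int4 values per little-endian 32-bit word) is an assumption. It relies only on the fact that 0x6400 is the fp16 bit pattern of 1024.0, where one unit in the last place equals 1, so OR-ing a 4-bit value into the mantissa yields 1024+n.

```python
import struct

MAGIC = 0x64006400  # two fp16 lanes, each holding the bit pattern of 1024.0
MASK = 0x000F000F   # lowest nibble of each 16-bit lane


def halves_from_u32(word):
    """Reinterpret a 32-bit word as two little-endian fp16 values (low lane first)."""
    return struct.unpack("<2e", struct.pack("<I", word))


def dequant_pair(packed, shift):
    """One and-or step: isolate one nibble per 16-bit lane, then OR in the magic bits.

    The result lanes hold 1024+n; this sketch subtracts 1024 explicitly.
    How the real kernel removes that offset is not shown here.
    """
    bits = ((packed >> shift) & MASK) | MAGIC
    lo, hi = halves_from_u32(bits)
    return lo - 1024.0, hi - 1024.0


def dequant_word(packed):
    """Unpack all eight nibbles of a 32-bit word, two per and-or step."""
    out = []
    for shift in (0, 4, 8, 12):
        out.extend(dequant_pair(packed, shift))
    return out


values = dequant_word(0x76543210)
print(values)
```

Under the assumed layout the values come out interleaved (nibble 0 pairs with nibble 4, nibble 1 with nibble 5, and so on), because each step takes one nibble from each 16-bit lane. A kernel using this scheme has to account for that order when the weights are packed or when results are stored.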
The decode path (wvSplitK_int4_g) is unchanged.
Benchmark
End-to-end median TTFT on gfx1151 across the W4A16 model suite, each model run
with its own reproducer (input length / dataset per model). BASE = origin/gfx11
kernel, BRANCH = this change (same build; only the prefill kernel file differs):
Median ≈ −15% TTFT (range −12% to −21%; Qwen3-1.7B is flat / within noise).
At the kernel level the bf16 asymmetric carrier is −4.8% (do_bench, M=2048), and the Qwen3-4B gate_up GEMM (K=2560, N=19456, g=128) reaches ~32 TFLOP/s fp16 at M=2048–4096.
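For readers converting between kernel time and throughput, a GEMM of shape M×K by K×N is conventionally counted as 2·M·N·K floating-point operations. The sketch below applies that convention; the helper name is hypothetical, and the 6.375 ms figure is back-computed from the ~32 TFLOP/s number above for illustration only, not a measured value.

```python
def gemm_tflops(m, n, k, elapsed_ms):
    """Throughput in TFLOP/s for an (m x k) @ (k x n) GEMM that took elapsed_ms."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    return flops / (elapsed_ms * 1e-3) / 1e12


# Qwen3-4B gate_up shape from the text, at M=2048.
# 6.375 ms is an illustrative time, chosen to correspond to roughly 32 TFLOP/s.
print(round(gemm_tflops(2048, 19456, 2560, 6.375), 1))
```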
Standalone GEMM throughput (fp16, M=2048, gfx1151)
Kernel-isolated do_bench of the W4A16 prefill GEMM (asymmetric / AWQ), mean of two runs; BASE = origin/gfx11 kernel, BRANCH = this change:
Median ≈ +24% GEMM throughput. The K=4096 square shapes (Qwen3-8B qkv/o)
are at parity — base and branch select the same tile config there. The bf16
asymmetric carrier is −4.8% at the kernel level (do_bench, M=2048).
Testing
Numerical correctness (gfx1151): the asymmetric carrier dequant matches a
reference to <1e-3 relative error in fp16 and is bit-identical in bf16, at
M=1 and M=2048 across the Qwen3-4B projections; the symmetric fast path is
numerically consistent with the carrier (bf16 bit-identical, fp16 ≤1 ULP).
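The bf16 bit-identity is what one would expect if the carrier only changes how the scale and zero-point are fetched, not the arithmetic applied to them. The following host-side sketch shows that idea under loudly stated assumptions: the function names are hypothetical, the layout (bf16 scale in the high 16 bits, integer zero-point in the low 16 bits of one 32-bit slot) is a guess for illustration, and the bf16 rounding is emulated in plain Python rather than taken from the kernel.

```python
import struct


def f32_bits(x):
    """Bit pattern of x rounded to fp32, as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bf16_bits(x):
    """Round x to bf16 (round-to-nearest-even) and return its 16-bit pattern."""
    b = f32_bits(x)
    b += 0x7FFF + ((b >> 16) & 1)
    return (b >> 16) & 0xFFFF


def bf16_to_float(h):
    """Expand a 16-bit bf16 pattern to a Python float."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]


def pack_carrier(scale, zp):
    """Assumed layout: bf16 scale in the high half, integer zero-point in the low half."""
    return (to_bf16_bits(scale) << 16) | (zp & 0xFFFF)


def dequant_carrier(nibble, carrier):
    """Single 32-bit fetch, then int-domain subtract and one scale multiply."""
    scale = bf16_to_float(carrier >> 16)
    zp = carrier & 0xFFFF
    return bf16_to_float(to_bf16_bits((nibble - zp) * scale))


def dequant_separate(nibble, scale_bits, zp):
    """Reference path: scale and zero-point arrive from two separate fetches."""
    scale = bf16_to_float(scale_bits)
    return bf16_to_float(to_bf16_bits((nibble - zp) * scale))


carrier = pack_carrier(0.0123, 7)
same = all(
    dequant_carrier(n, carrier) == dequant_separate(n, to_bf16_bits(0.0123), 7)
    for n in range(16)
)
print(same)
```

Because both paths run the same subtract and multiply on the same bf16 scale, the comparison is exact rather than tolerance-based. The fp16 carrier uses a different formulation (scale | bias_eff with a packed FMA), which is why it is checked against a relative-error bound instead.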
Perf regression test:
Covers fp16/bf16 × symmetric/asymmetric across the Qwen / Gemma / Llama
prefill shape catalog (incl. added Qwen3-1.7B and Gemma3-4B shapes). The
prefill (hybrid_triton_w4a16) baselines are regenerated for gfx1151; the
decode (wvSplitK_int4_g) baselines are unchanged (the kernel is untouched).
Prefill cells pass within the tolerance band. A tiny-M (M≤4) decode cell can
intermittently exceed the ±8% band due to run-to-run variance on the unchanged
HIP decode kernel — not affected by this change.
Lint:
ruff check, ruff format, and mypy pass (full pre-commit suite green).
Not a duplicate
This tunes the existing in-tree HybridW4A16LinearKernel Triton prefill path on
the gfx11 branch and adds the scale/zp dequant carrier. It does not overlap
with the Marlin or CK W4A16 paths.
AI assistance
AI assistance (Claude) was used to author this change. The diff and benchmarks
were reviewed and run by the submitter.
Files changed (3):
vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py
tests/kernels/quantization/test_hybrid_w4a16_perf.py
tests/kernels/quantization/golden/hybrid_w4a16_gfx1151.json