Skip to content

[ROCm][Kernel] W4A16 prefill: optimize dequant#985

Open
mgehre-amd wants to merge 1 commit into
gfx11from
matthias.triton-w4a16-skinny-packedsb
Open

[ROCm][Kernel] W4A16 prefill: optimize dequant#985
mgehre-amd wants to merge 1 commit into
gfx11from
matthias.triton-w4a16-skinny-packedsb

Conversation

@mgehre-amd

@mgehre-amd mgehre-amd commented Jun 3, 2026

Copy link
Copy Markdown

Summary

Speeds up the Triton W4A16 skinny prefill GEMM on gfx11x (RDNA3/3.5). The
prefill dequant inner loop is VALU-issue-bound; this change cuts the dequant
instruction count and, for asymmetric layers, folds the per-group scale and
zero-point into a single load. ~14–15% lower prefill TTFT on a W4A16 model with
no accuracy change. Pure Python/Triton — no C++/HIP rebuild.

What it does

vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py:

  • Packed int4→fp16 dequant (_i4_and_or_magic): one v_and_or_b32
    dequants two int4 nibbles into fp16 1024+n per instruction (the
    i4_to_half magic-constant trick), replacing the scalar v_and_b16 /
    v_or_b16 pair. The kernel selects this path at JIT time — fp16 and RDNA
    gfx11/gfx12, via a tl.target_info.constexpr_function (_target_is_gfx1x);
    no host flag is passed. All other cases use the scalar unpack.

  • Distilled gfx11 tile table (_select_skinny_gfx11_config): BLOCK_N=256,
    num_warps=8, BLOCK_K=32, with BLOCK_M 64/128 chosen by K and N. Other
    architectures keep their existing config.

  • Asymmetric layers — packed_scale_zp carrier: one fp32 per (n, group)
    packs the per-group scale and zero-point offset, so the prefill kernel does a
    single load instead of separate scale + zero-point loads. fp16 packs
    scale | bias_eff and consumes it with one v_pk_fma_f16; bf16 packs
    scale | zp_int and keeps an int-domain subtract (RDNA3 has no
    v_pk_fma_bf16). The carrier is materialised at load time only for
    asymmetric layers; the kernel’s HAS_ZP constexpr selects carrier vs scales.

  • Symmetric layers — dedicated fast path: the −8 offset is a constant, so
    there is no second load to fold and a carrier would be pure overhead. Sym
    layers read the scale directly and subtract the constant.

The decode path (wvSplitK_int4_g) is unchanged.

Benchmark

End-to-end median TTFT on gfx1151 across the W4A16 model suite, each model run
with its own reproducer (input length / dataset per model). BASE = origin/gfx11
kernel, BRANCH = this change (same build; only the prefill kernel file differs):

model base branch Δ
Qwen2.5-7B-AWQ (in=3968) 2383.4 ms 1889.7 ms −20.7%
Qwen2.5-3B-AWQ (in=3968) 1166.2 ms 931.7 ms −20.1%
SmolLM2-1.7B-AWQ (in=8000) 1998.5 ms 1674.8 ms −16.2%
Llama-2-7B-AWQ (in=1920) 1129.3 ms 952.5 ms −15.7%
Qwen3-8B-AWQ (in=3968) 2925.7 ms 2472.5 ms −15.5%
Qwen3-8B compressed-tensors w4a16 (in=3968) 2926.2 ms 2478.2 ms −15.3%
Qwen3-4B-AWQ (in=3968) 1485.5 ms 1263.1 ms −15.0%
Qwen2.5-VL-7B-AWQ 1237.5 ms 1053.4 ms −14.9%
Qwen2.5-VL-3B-AWQ 774.8 ms 682.1 ms −12.0%
Qwen3-1.7B-AWQ (in=3968) 654.7 ms 660.8 ms +0.9%

Median ≈ −15% TTFT (range −12% to −21%; Qwen3-1.7B is flat / within noise).
At the kernel level the bf16 asymmetric carrier is −4.8% (do_bench, M=2048),
and the Qwen3-4B gate_up GEMM (K=2560, N=19456, g=128) reaches ~32 TFLOP/s fp16
at M=2048–4096.

Standalone GEMM throughput (fp16, M=2048, gfx1151)

Kernel-isolated do_bench of the W4A16 prefill GEMM (asymmetric / AWQ),
mean of two runs; BASE = origin/gfx11 kernel, BRANCH = this change:

projection K N base TFLOP/s branch TFLOP/s Δ
Qwen3-4B qkv 2560 3840 25.5 33.8 +32.7%
Qwen3-4B o 2560 2560 23.6 30.3 +28.7%
Qwen3-4B gate_up 2560 19456 24.3 30.4 +25.0%
Qwen3-4B down 9728 2560 21.6 31.7 +46.8%
Qwen3-8B qkv 4096 6144 19.6 19.5 −0.6%
Qwen3-8B o 4096 4096 17.3 16.4 −4.8%
Qwen3-8B gate_up 4096 24576 23.4 26.8 +14.4%
Qwen3-8B down 12288 4096 13.7 17.0 +23.8%
Qwen2.5-7B qkv 3584 4608 25.4 29.5 +16.3%
Qwen2.5-7B o 3584 3584 25.3 29.6 +17.1%
Qwen2.5-7B gate_up 3584 37888 24.6 31.1 +26.5%
Qwen2.5-7B down 18944 3584 21.0 30.5 +45.2%

Median ≈ +24% GEMM throughput. The K=4096 square shapes (Qwen3-8B qkv/o)
are at parity — base and branch select the same tile config there. The bf16
asymmetric carrier is −4.8% at the kernel level (do_bench, M=2048).

Testing

  • Numerical correctness (gfx1151): the asymmetric carrier dequant matches a
    reference to <1e-3 relative error in fp16 and is bit-identical in bf16, at
    M=1 and M=2048 across the Qwen3-4B projections; the symmetric fast path is
    numerically consistent with the carrier (bf16 bit-identical, fp16 ≤1 ULP).

  • Perf regression test:

    .venv/bin/python -m pytest tests/kernels/quantization/test_hybrid_w4a16_perf.py -v
    

    Covers fp16/bf16 × symmetric/asymmetric across the Qwen / Gemma / Llama
    prefill shape catalog (incl. added Qwen3-1.7B and Gemma3-4B shapes). The
    prefill (hybrid_triton_w4a16) baselines are regenerated for gfx1151; the
    decode (wvSplitK_int4_g) baselines are unchanged (the kernel is untouched).
    Prefill cells pass within the tolerance band. A tiny-M (M≤4) decode cell can
    intermittently exceed the ±8% band due to run-to-run variance on the unchanged
    HIP decode kernel — not affected by this change.

  • Lint: ruff check, ruff format, and mypy pass (full pre-commit suite
    green).

Not a duplicate

This tunes the existing in-tree HybridW4A16LinearKernel Triton prefill path on
the gfx11 branch and adds the scale/zp dequant carrier. It does not overlap
with the Marlin or CK W4A16 paths.

AI assistance

AI assistance (Claude) was used to author this change. The diff and benchmarks
were reviewed and run by the submitter.


Files changed (3):

  • vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py
  • tests/kernels/quantization/test_hybrid_w4a16_perf.py
  • tests/kernels/quantization/golden/hybrid_w4a16_gfx1151.json

@mgehre-amd mgehre-amd force-pushed the matthias.triton-w4a16-skinny-packedsb branch 3 times, most recently from e6aad07 to b04ceaf Compare June 5, 2026 13:52
@mgehre-amd mgehre-amd changed the title [ROCm][Kernel] W4A16 prefill: optimize [ROCm][Kernel] W4A16 prefill: optimize dequant Jun 5, 2026
@mgehre-amd mgehre-amd force-pushed the matthias.triton-w4a16-skinny-packedsb branch from b04ceaf to d11fb3e Compare June 5, 2026 22:25
Comment thread vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py Outdated
Comment thread vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py Outdated
Comment thread vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py Outdated
@mgehre-amd mgehre-amd force-pushed the matthias.triton-w4a16-skinny-packedsb branch from d11fb3e to aeeea16 Compare June 8, 2026 12:25
… carrier

Speeds up the Triton W4A16 skinny prefill GEMM on gfx11x (RDNA3/3.5). The
prefill dequant inner loop was VALU-issue-bound; this cuts the dequant
instruction count and, for asymmetric layers, folds the per-group scale +
zero-point into a single load. ~14-15% lower TTFT on a W4A16 model with no
accuracy change. Pure Python/Triton; no C++/HIP rebuild.

Changes:
- Packed int4->fp16 dequant (_i4_and_or_magic): one v_and_or_b32 dequants two
  nibbles into fp16 (1024+n) per instruction (the i4_to_half magic trick),
  replacing the scalar v_and_b16/v_or_b16 pair. The kernel selects the packed
  path itself at JIT time -- fp16 AND RDNA gfx11/gfx12, via a
  tl.target_info.constexpr_function (_target_is_gfx1x); no host flag. Everything
  else uses the scalar unpack.
- Distilled gfx11 tile table (_select_skinny_gfx11_config): BLOCK_N=256,
  num_warps=8, BLOCK_K=32, BLOCK_M 64/128 by K and N. Other arches unchanged.
- Asymmetric layers: packed_scale_zp carrier (one fp32 per (n, group)) folds the
  per-group scale and zero-point offset into a single load. fp16 packs
  scale|bias_eff and consumes it with one v_pk_fma_f16; bf16 packs scale|zp_int
  and keeps the int-domain subtract (RDNA3 has no v_pk_fma_bf16). Materialised
  only for asym layers; the kernel's HAS_ZP constexpr selects carrier vs scales.
- Symmetric layers keep a dedicated fast path (scales + constant -8 offset): sym
  has no second load to fold, so the carrier would be pure overhead (measured
  ~+8% on fp16 sym, up to +22% on o_proj).
- perf test: exercises the carrier for asym providers, adds Qwen3-1.7B and
  Gemma3-4B prefill shapes; gfx1151 golden regenerated.

Measured on gfx1151 (Qwen3-4B-AWQ, fp16, input 3968 / output 1): TTFT
1436 -> 1234 ms (-14.1%). bf16 asym carrier -4.8% (do_bench). Carrier dequant
verified numerically (fp16 <1e-3 rel error; bf16 bit-identical).

Verified: tests/kernels/quantization/test_hybrid_w4a16_perf.py 111/112 pass --
the one failure is a pre-existing flaky tiny-M wvSplitK_int4 decode cell
(unchanged HIP kernel), unrelated to this change. ruff + mypy clean.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd force-pushed the matthias.triton-w4a16-skinny-packedsb branch from aeeea16 to f9dcc75 Compare June 8, 2026 13:23
@mgehre-amd mgehre-amd marked this pull request as ready for review June 8, 2026 15:56
@mgehre-amd mgehre-amd requested review from marcusr-amd and removed request for AndreasKaratzas June 8, 2026 15:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant