Skip to content

wvSplitK int4: pad weight K-stride by +128 B on gfx1151#989

Open
mgehre-amd wants to merge 2 commits into
gfx11from
matthias.int4-stride-pad-experiment
Open

wvSplitK int4: pad weight K-stride by +128 B on gfx1151#989
mgehre-amd wants to merge 2 commits into
gfx11from
matthias.int4-stride-pad-experiment

Conversation

@mgehre-amd

@mgehre-amd mgehre-amd commented Jun 8, 2026

Copy link
Copy Markdown

When innermost dimension of the weights matrix is a multiple of 2048 bytes (4096 int4 elements), pad by 128 bytes to avoid a DDR bandwidth cliff.

The padding is performed on the python side. The HIP kernel gets a new runtime argument for the stride to skip over the padding. This keeps computations the same.

The effect only exists on gfx1151 (either due to MALL or due to multiple DDR channels), but not on gfx1150.
All PMC counters are the same with and without the padding change, except the total number of cycles is reduced.
In ATT, we see that the stall cycles on the wait for DDR reads reduce.

Shape table is (out_features, in_features) = (M, K). Padded entries are those where K_packed = K/2 is a multiple of 2048 B (i.e. K % 4096 == 0) — the gating condition in process_weights_after_loading.

Shape (M×K) Role unpad GiB/s padded GiB/s Δ%
4096×4096 Llama-8B q/o, Qwen3-8B o 153.1 170.3 +11.4
6144×4096 Llama-8B qkv, Qwen3-8B qkv 154.4 183.3 +18.5
28672×4096 Llama-8B gate_up 195.1 195.7 +0.1
24576×4096 Qwen3-8B gate_up 193.4 194.1 +0.3
11008×4096 Llama2-7B up/gate 177.6 182.3 +2.7
22016×4096 Llama2-7B gate_up 189.1 192.4 +2.1
2048×4096 Qwen3.5-A3B GDN out_proj 129.1 149.4 +16.0
2048×8192 SmolLM2 down 103.4 163.9 +57.5
512×8192 L2-boundary perf-test 65.1 120.3 +83.5
4096×12288 Qwen3-8B down 174.0 178.0 +2.4
2048×16384 Gemma-2B down 135.9 164.3 +20.9
6144×2560 Qwen3-4B qkv (control, K_packed%2048≠0) 168.0 165.2 −1.1
4608×3584 Qwen7B qkv (control, K_packed%2048≠0) 172.6 170.8 −1.0

This change also helps TTFT (prefill kernel hits the same DDR issue). On SmolLM2-1.7B-AWQ, --model trymirai/SmolLM2-1.7B-Instruct-AWQ --num-prompts 10 --max-model-len 256 --input-len 128 --output-len 128 --dtype float16 --target-gpu-memory-gb 10 --max-num-seqs 1 -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE -e TORCH_BLAS_PREFER_HIPBLASLT=1,
we get

Metric this PR before improvement
Median TPOT (ms) 6.66 7.30 −8.8%
Median TTFT (ms) 73.42 83.32 −11.9%

@mgehre-amd mgehre-amd force-pushed the matthias.int4-stride-pad-experiment branch 8 times, most recently from 6f5a1ff to edb885b Compare June 8, 2026 10:46
@mgehre-amd mgehre-amd changed the title wvSplitK int4: pad weight K-stride by +128 B on gfx1151 K%4096 cliff wvSplitK int4: pad weight K-stride by +128 B on gfx1151 Jun 8, 2026
@mgehre-amd mgehre-amd marked this pull request as ready for review June 8, 2026 13:46
@mgehre-amd mgehre-amd requested a review from dllehr-amd as a code owner June 8, 2026 13:46
@mgehre-amd mgehre-amd requested review from eble-amd and marcusr-amd and removed request for dllehr-amd June 8, 2026 13:46
Comment thread csrc/rocm/skinny_gemms_int4.cu Outdated
Comment thread csrc/rocm/skinny_gemms_int4.cu Outdated
Comment thread csrc/rocm/skinny_gemms_int4_kernels.cuh Outdated
Wires the B_row_stride_bytes plumbing all the way from the Python weight
loader to the int4 wvSplitK kernel so the production decode path benefits
from the gfx1151 K_packed%2048 BW cliff workaround.

Kernel layer:
- wvSplitK_int4_g now derives the per-row byte stride from in_a.stride(0)
  instead of computing K_in/2 inline (drops the strict M*K/2 size check).
  The int64 stride is bounds-checked via std::in_range<int> at the
  wrapper and narrowed once into b_row_stride_bytes_i32, so the launch
  macros pass an int directly with no further cast.
- WVSPLITK_INT4G_LAUNCH_W_AC macro passes b_row_stride_bytes through to
  the kernel templates already updated in the previous commit.

Triton layer:
- _triton_w4a16_skinny_fmt_kernel takes a stride_bn argument and uses it
  in b_ptrs instead of hardcoding K8.  The Python launcher reads
  b_q.stride(0) and passes it.  Weight layout is unchanged when stride(0)
  == K//8 (contiguous), so this is a no-op outside the padded layout.

Python layer (process_weights_after_loading):
- When K_packed = K/2 lands on a 2048 B multiple (i.e. K % 4096 == 0 --
  Llama-8B q/o, SmolLM2 down_proj, etc.) AND running on gfx1151,
  allocate the packed-weight buffer with 32 extra trailing int32 columns
  per row (+128 B = one cache line).  Both the int8 (skinny) and int32
  (triton) views of the shared buffer inherit the padded stride; the
  ~3% weight-memory overhead is only paid on the affected layers.

Gating: gated to gfx1151 only via a new on_gfx1151() helper.  Empirical
data on gfx1150 (Strix Point HX 370, 8 CUs, 128-bit DDR) shows no
measurable difference between contiguous and padded weights on the
cliff shape (N=2048, K=8192) over 3 runs, ratio 0.995-1.008.  On gfx1151
(Strix Halo, 20 CUs, 256-bit DDR + 32 MB MALL) the same shape lifts
from 153.8 to 206.2 GiB/s (+34%).  The exact mechanism is a
memory-subsystem hash collision somewhere between L2 and DRAM --
per-L2-slice GL2C counters are byte-identical between variants, only
SQ_WAVE_CYCLES grows (+14%) with ATT pinning the extra cycles to
s_waitcnt vmcnt drains.  Whether the collision is at DRAM channel-select
bits or MALL slice/bank bits is not directly distinguishable with the
counters exposed on gfx1151 (rocprofv3 doesn't expose UMC or MALL
counters), but either way the mitigation -- offsetting the stride by
one cache line -- is the same, and the gfx1150 data point confirms the
fix is correctly gated.

Non-cliff shapes (Qwen3-1.7B q/o/gate_up, SmolLM2 QKV, etc.) skip the
padding branch entirely and remain bit-identical to before.

The padded layout interacts with the asymmetric-quantization Triton
path and the bf16 zp fast path that were tuned against the contiguous
layout; broader correctness testing is needed before this lands on the
production branch.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd force-pushed the matthias.int4-stride-pad-experiment branch from edb885b to 3638754 Compare June 8, 2026 22:02
The wvSplitK family historically names its parameters `in_a` (weight)
and `in_b` (activation), which is the reverse of the conventional GEMM
A/B naming.  Reviewer-confused: addresses #989 review comment.

Renames the three int4 ops affected by this PR (production + 2 sweep
variants) plus their declarations, schema, and Python register_fake:
- wvSplitK_int4_g
- wvSplitK_int4g_sweep
- wvSplitK_int4g_hf_sweep

Mechanical replacement; no behaviour change.  The sibling wvSplitK ops
(bf16, int8, w8a8, fused-silu-mul, etc.) still use `in_a/in_b` and are
left as-is to keep the rename scoped to the ops this PR already touches.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd requested a review from eble-amd June 8, 2026 22:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants