wvSplitK int4: pad weight K-stride by +128 B on gfx1151 #989
Open
mgehre-amd wants to merge 2 commits into
eble-amd
reviewed
Jun 8, 2026
Wires the B_row_stride_bytes plumbing all the way from the Python weight loader to the int4 wvSplitK kernel so the production decode path benefits from the gfx1151 K_packed%2048 BW cliff workaround.

Kernel layer:
- wvSplitK_int4_g now derives the per-row byte stride from in_a.stride(0) instead of computing K_in/2 inline (drops the strict M*K/2 size check). The int64 stride is bounds-checked via std::in_range<int> at the wrapper and narrowed once into b_row_stride_bytes_i32, so the launch macros pass an int directly with no further cast.
- WVSPLITK_INT4G_LAUNCH_W_AC macro passes b_row_stride_bytes through to the kernel templates already updated in the previous commit.

Triton layer:
- _triton_w4a16_skinny_fmt_kernel takes a stride_bn argument and uses it in b_ptrs instead of hardcoding K8. The Python launcher reads b_q.stride(0) and passes it. Weight layout is unchanged when stride(0) == K//8 (contiguous), so this is a no-op outside the padded layout.

Python layer (process_weights_after_loading):
- When K_packed = K/2 lands on a 2048 B multiple (i.e. K % 4096 == 0 -- Llama-8B q/o, SmolLM2 down_proj, etc.) AND running on gfx1151, allocate the packed-weight buffer with 32 extra trailing int32 columns per row (+128 B = one cache line). Both the int8 (skinny) and int32 (triton) views of the shared buffer inherit the padded stride; the ~3% weight-memory overhead is only paid on the affected layers.

Gating: gated to gfx1151 only via a new on_gfx1151() helper. Empirical data on gfx1150 (Strix Point HX 370, 8 CUs, 128-bit DDR) shows no measurable difference between contiguous and padded weights on the cliff shape (N=2048, K=8192) over 3 runs, ratio 0.995-1.008. On gfx1151 (Strix Halo, 20 CUs, 256-bit DDR + 32 MB MALL) the same shape lifts from 153.8 to 206.2 GiB/s (+34%).

The exact mechanism is a memory-subsystem hash collision somewhere between L2 and DRAM -- per-L2-slice GL2C counters are byte-identical between variants, only SQ_WAVE_CYCLES grows (+14%) with ATT pinning the extra cycles to s_waitcnt vmcnt drains. Whether the collision is at DRAM channel-select bits or MALL slice/bank bits is not directly distinguishable with the counters exposed on gfx1151 (rocprofv3 doesn't expose UMC or MALL counters), but either way the mitigation -- offsetting the stride by one cache line -- is the same, and the gfx1150 data point confirms the fix is correctly gated.

Non-cliff shapes (Qwen3-1.7B q/o/gate_up, SmolLM2 QKV, etc.) skip the padding branch entirely and remain bit-identical to before.

The padded layout interacts with the asymmetric-quantization Triton path and the bf16 zp fast path that were tuned against the contiguous layout; broader correctness testing is needed before this lands on the production branch.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
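The gating rule and the memory overhead described in the commit message can be checked with a few lines of arithmetic. This is a minimal sketch, not the code from the PR: the function names and constants below are hypothetical, and the `on_gfx1151` flag stands in for the `on_gfx1151()` helper the commit mentions.

```python
CACHE_LINE_BYTES = 128  # the +128 B pad (32 int32 columns) from the commit message
CLIFF_BYTES = 2048      # K_packed multiples that hit the bandwidth cliff


def packed_row_stride_bytes(k: int, on_gfx1151: bool) -> int:
    """Per-row byte stride of an int4 weight row with k elements (hypothetical helper)."""
    k_packed = k // 2  # two int4 values per byte
    if on_gfx1151 and k_packed % CLIFF_BYTES == 0:
        return k_packed + CACHE_LINE_BYTES  # padded layout
    return k_packed  # contiguous layout, unchanged


def padding_overhead(k: int) -> float:
    """Fractional weight-memory overhead of the pad for one row of k int4 elements."""
    return CACHE_LINE_BYTES / (k // 2)
```

For the cliff shape with K=8192 this gives a stride of 4224 B and an overhead of 128/4096, about 3.1%; for K=4096 the same pad amounts to 128/2048, or 6.25%.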
The wvSplitK family historically names its parameters `in_a` (weight) and `in_b` (activation), which is the reverse of the conventional GEMM A/B naming. This confused a reviewer; the change addresses the #989 review comment.

Renames the three int4 ops affected by this PR (production + 2 sweep variants) plus their declarations, schema, and Python register_fake:
- wvSplitK_int4_g
- wvSplitK_int4g_sweep
- wvSplitK_int4g_hf_sweep

Mechanical replacement; no behaviour change. The sibling wvSplitK ops (bf16, int8, w8a8, fused-silu-mul, etc.) still use `in_a/in_b` and are left as-is to keep the rename scoped to the ops this PR already touches.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
When the innermost dimension of the weight matrix is a multiple of 2048 bytes (4096 int4 elements), pad each row by 128 bytes to avoid a DDR bandwidth cliff.
The padding is performed on the Python side. The HIP kernel gets a new runtime argument for the row stride so that it skips over the padding. This keeps the computation unchanged.
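Why a runtime stride keeps the results unchanged can be illustrated with a plain byte buffer. The sketch below is an illustration only, assuming a row-major layout with the pad at the end of each row; `pack_rows` and `load_byte` are hypothetical names, not functions from this PR.

```python
PAD_BYTES = 128  # one cache line of trailing padding per row


def pack_rows(rows, pad):
    """Copy equally sized packed rows into one buffer with `pad` trailing bytes per row."""
    row_bytes = len(rows[0])
    stride = row_bytes + pad  # distance in bytes between the starts of two rows
    buf = bytearray(stride * len(rows))
    for m, row in enumerate(rows):
        buf[m * stride : m * stride + row_bytes] = row
    return buf, stride


def load_byte(buf, stride, m, k_byte):
    """Address a packed byte the way a strided kernel would: row * stride + column."""
    return buf[m * stride + k_byte]


# Three tiny rows of 8 packed bytes each, stored contiguously and padded.
rows = [bytes([i] * 8) for i in range(3)]
contiguous, stride_c = pack_rows(rows, 0)
padded, stride_p = pack_rows(rows, PAD_BYTES)
same = all(
    load_byte(contiguous, stride_c, m, k) == load_byte(padded, stride_p, m, k)
    for m in range(3)
    for k in range(8)
)
```

Because every access goes through the stride, the padding bytes are never read, so both layouts feed identical values into the computation; only the addresses differ.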
The effect only exists on gfx1151 (due either to the MALL or to its multiple DDR channels), not on gfx1150.
All PMC counters are the same with and without the padding, except that the total number of cycles is reduced.
In ATT, we see that the stall cycles spent waiting for DDR reads are reduced.
The shape table uses (out_features, in_features) = (M, K). Padded entries are those where `K_packed = K/2` is a multiple of 2048 B (i.e. `K % 4096 == 0`), which is the gating condition in `process_weights_after_loading`.

This change also helps TTFT (the prefill kernel hits the same DDR issue). On SmolLM2-1.7B-AWQ, with
`--model trymirai/SmolLM2-1.7B-Instruct-AWQ --num-prompts 10 --max-model-len 256 --input-len 128 --output-len 128 --dtype float16 --target-gpu-memory-gb 10 --max-num-seqs 1 -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE -e TORCH_BLAS_PREFER_HIPBLASLT=1`,
we get