Skip to content

Support padded K for a8w8 bpreshuffle GEMM#3611

Open
ThomasNing wants to merge 1 commit into
ROCm:mainfrom
ThomasNing:thomas/a8w8-bpreshuffle-pad-k
Open

Support padded K for a8w8 bpreshuffle GEMM#3611
ThomasNing wants to merge 1 commit into
ROCm:mainfrom
ThomasNing:thomas/a8w8-bpreshuffle-pad-k

Conversation

@ThomasNing
Copy link
Copy Markdown

What

Add an opt-in K-padding path for a8w8 bpreshuffle GEMM:

  • shuffle_weight(..., pad_k_to=...) pads the weight tensor's last dimension before preshuffle.
  • gemm_a8w8_bpreshuffle accepts padded preshuffled weights by zero-padding XQ to WQ.shape[-1] before config lookup and kernel dispatch.
  • gemm_a8w8_bpreshuffle rejects WQ.shape[-1] < XQ.shape[-1] with a clear error.

The default behavior is unchanged because pad_k_to=0.

Why

SGLang compressed-tensors FP8 loading can produce valid a8w8 GEMM shapes whose K dimension is not aligned for bpreshuffle/tuned CKTile dispatch. One real GLM-4.5-Air FP8 shape on MI300X is K=10944; padding it to 11008 allows the preshuffled path to run, but the dynamic activation must be padded consistently at GEMM dispatch time.

Downstream can work around this in SGLang, but the robust behavior belongs in AITER because AITER owns both shuffle_weight and gemm_a8w8_bpreshuffle.

Tests

Added op_tests/test_gemm_a8w8_bpreshuffle_pad_k.py covering:

  • shuffle_weight(..., pad_k_to=128) pads the last dim and records original/padded K metadata.
  • gemm_a8w8_bpreshuffle pads XQ before backend dispatch when WQ is wider.
  • gemm_a8w8_bpreshuffle rejects a shorter WQ.

Local syntax check:

python3 -m py_compile aiter/ops/shuffle.py aiter/ops/gemm_op_a8w8.py op_tests/test_gemm_a8w8_bpreshuffle_pad_k.py

Related

Mixed-dtype GQA attention support request for CK Tile: ROCm/composable_kernel#3744

@ThomasNing ThomasNing requested a review from a team June 8, 2026 10:49
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 8, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3611 --add-label <label>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant