From 0e3d214bd11d8c75934037043e47cd151ba98491 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 21 May 2026 05:59:03 -0600 Subject: [PATCH 1/3] [kernels] hybrid_w4a16_moe: add group_size<=64 tune for Qwen3-Omni The existing Triton prefill MoE tune (commits 14901427fb + 8af2e37b0c) was derived from Qwen3.5-A3B's shape (group_size=128, E=256). At small group_size the kernel wrapper caps BLOCK_K to group_size, so the narrow-BN/small-BM strategy that wins for group_size=128 becomes severely under-occupied -- regressing Qwen3-Omni-30B-A3B-AWQ-4bit (group_size=32, E=128) TTFT by ~20% on Strix Halo (gfx1151). Add a second tuned branch in _triton_config / _select_block_size_m gated by `self._group_size <= 64` (a hard cutoff between the only two production group_sizes we see, 32 and 128). The group_size > 64 path is byte-identical to before, so the existing Qwen3.5-A3B 2x speedup is preserved. Configs derived from a kernel-only sweep at Qwen3-Omni shape (Strix Halo gfx1151, M=2048 N=768 K=2048 E=128 top_k=8 group_size=32) via the new benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py tool: gemm1 (K=2048 N=1536): BM=128 BN=64 GM=1 nw=8 ns=1 -> 6.79 ms vs 11.27 ms at the Qwen3.5 tune (1.66x). gemm2 (K=768 N=2048): BM=64 BN=64 GM=1 nw=4 ns=1 -> 2.58 ms vs 3.18 ms at the Qwen3.5 tune (1.23x). Alignment block_size_m = lcm(128, 64) = 128 (the new TRITON_BLOCK_SIZE_M_SMALL_GS); gemm2 uses the existing _expert_ids_for repeat_interleave path so each 64-row sub-block sees the right expert id. This is the first caller to actually exercise BLOCK_M != alignment -- the infrastructure was already in place from 8af2e37b0c. End-to-end vLLM serving TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct- AWQ-4bit (--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1): before this patch: 2228 ms (bad baseline, current gfx11 tip) after this patch: 1867 ms (-16%) Changes: - bench_hybrid_w4a16_moe.py's per-call timings are dominated by host setup (weight quant, sort, _resize_cache), so it can't resolve the per-config differences this tune relies on. The new sweep tool calls invoke_fused_moe_kernel_hybrid_triton directly to time only the kernel. - Bumped atol 2e-2 -> 3e-2 on test_hybrid_w4a16_moe_force_triton. The new tune's larger BM and num_warps change the partial-sum reduction order, which pushes the n=k=256 stress shape's max abs diff to ~0.027. Real-model TTFT precision is unaffected (large K averages out per-tile rounding). Signed-off-by: Matthias Gehre --- .../kernels/sweep_hybrid_w4a16_moe_triton.py | 309 ++++++++++++++++++ tests/kernels/moe/test_hybrid_w4a16_moe.py | 3 +- .../layers/fused_moe/hybrid_w4a16_moe.py | 43 ++- 3 files changed, 348 insertions(+), 7 deletions(-) create mode 100644 benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py diff --git a/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py b/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py new file mode 100644 index 000000000000..5cbb7deebab5 --- /dev/null +++ b/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Kernel-only sweep for the hybrid_w4a16 MoE prefill Triton kernel. + +Calls ``invoke_fused_moe_kernel_hybrid_triton`` directly (skipping the +HybridW4A16MoEExperts.apply host setup so kernel time is not drowned +out, unlike benchmarks/kernels/bench_hybrid_w4a16_moe.py). Sweeps the +full Triton config space and ranks results per gemm. + +Default shapes target Qwen3-Omni-30B-A3B (E=128, group_size=32, +hidden=2048, moe_intermediate=768) so the tune for this model can be +re-derived from scratch. Override with --n / --k / --e / --group-size +to retune for other shapes (e.g. --n 512 --k 2048 --e 256 +--group-size 128 for Qwen3.5-A3B). + +Usage: + python benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py + python benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py \\ + --m 2048 --n 512 --k 2048 --e 256 --topk 8 --group-size 128 \\ + --csv /scratch/mgehre/tmp/sweep_qwen35.csv +""" + +from __future__ import annotations + +import argparse +import csv +import itertools +import sys +from pathlib import Path + +import torch + +REPO = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO)) + +from tests.kernels.moe.test_hybrid_w4a16_moe import ( # noqa: E402 + _make_hybrid_moe_weights, +) +from vllm.model_executor.layers.fused_moe import fused_topk # noqa: E402 +from vllm.model_executor.layers.fused_moe.fused_moe import ( # noqa: E402 + invoke_fused_moe_kernel_hybrid_triton, +) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( # noqa: E402 + moe_align_block_size, +) +from vllm.triton_utils import tl, triton # noqa: E402 + + +def build_inputs( + M: int, + N_inter: int, + K_hidden: int, + E: int, + topk: int, + group_size: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", +): + """Build weights, hidden, and topk_ids once for the whole sweep.""" + w1, w1_s, _ = _make_hybrid_moe_weights(E, K_hidden, 2 * N_inter, group_size, device) + w2, w2_s, _ = _make_hybrid_moe_weights(E, N_inter, K_hidden, group_size, device) + hidden = torch.randn(M, K_hidden, device=device, dtype=dtype) / 10.0 + scores = torch.randn(M, E, device=device, dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden, scores, topk, False) + return dict( + w1=w1, + w1_s=w1_s, + w2=w2, + w2_s=w2_s, + hidden=hidden, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + + +def time_us(fn, warmup_ms: int = 25, rep_ms: int = 80) -> float: + return ( + triton.testing.do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="median") + * 1000.0 + ) + + +def make_gemm1_fn(inputs, cfg, group_size, E): + """Closure that runs only gemm1 (hidden -> 2*N_inter).""" + block_m = cfg["BLOCK_SIZE_M"] + sorted_ids, expert_ids, npp = moe_align_block_size( + inputs["topk_ids"], block_m, E, None, ignore_invalid_experts=True + ) + num_slots = sorted_ids.size(0) + N_w = inputs["w1"].size(1) # 2*N_inter + out = torch.empty( + num_slots, N_w, device=inputs["hidden"].device, dtype=inputs["hidden"].dtype + ) + compute_type = ( + tl.float16 if inputs["hidden"].dtype == torch.float16 else tl.bfloat16 + ) + + def run(): + invoke_fused_moe_kernel_hybrid_triton( + A=inputs["hidden"], + B=inputs["w1"], + C=out, + B_scale=inputs["w1_s"], + topk_weights=None, + sorted_token_ids=sorted_ids, + expert_ids=expert_ids, + num_tokens_post_padded=npp, + mul_routed_weight=False, + top_k=inputs["topk_ids"].size(1), + config=cfg, + compute_type=compute_type, + group_size=group_size, + ) + + return run + + +def make_gemm2_fn(inputs, cfg, group_size, E): + """Closure that runs only gemm2 (N_inter -> hidden). + + Mimics apply()'s second call: A is in slot-space (one row per slot + with the post-activation activations), B is w2, top_k is 1 so the + kernel reads A[slot] directly. + """ + block_m = cfg["BLOCK_SIZE_M"] + sorted_ids, expert_ids, npp = moe_align_block_size( + inputs["topk_ids"], block_m, E, None, ignore_invalid_experts=True + ) + num_slots = sorted_ids.size(0) + # w2: [E, N=K_hidden, K_in=N_inter//8 (int32, holds N_inter//8 int4 elems each)] + K_in = inputs["w2"].size(2) * 8 + K_hidden = inputs["w2"].size(1) + act = ( + torch.randn( + num_slots, + K_in, + device=inputs["hidden"].device, + dtype=inputs["hidden"].dtype, + ) + / 10.0 + ) + out = torch.empty( + num_slots, + K_hidden, + device=inputs["hidden"].device, + dtype=inputs["hidden"].dtype, + ) + compute_type = ( + tl.float16 if inputs["hidden"].dtype == torch.float16 else tl.bfloat16 + ) + + def run(): + invoke_fused_moe_kernel_hybrid_triton( + A=act, + B=inputs["w2"], + C=out, + B_scale=inputs["w2_s"], + topk_weights=None, + sorted_token_ids=sorted_ids, + expert_ids=expert_ids, + num_tokens_post_padded=npp, + mul_routed_weight=False, + top_k=1, + config=cfg, + compute_type=compute_type, + group_size=group_size, + ) + + return run + + +def sweep(args): + inputs = build_inputs(args.m, args.n, args.k, args.e, args.topk, args.group_size) + torch.accelerator.synchronize() + + # Each (BM, BN, BK, GM, nw, ns) candidate. BK is capped to + # group_size inside the wrapper, so listing values > group_size is + # equivalent to BK=group_size. + bm_list = [int(x) for x in args.block_m] + bn_list = [int(x) for x in args.block_n] + bk_list = [int(x) for x in args.block_k] + gm_list = [int(x) for x in args.group_m] + nw_list = [int(x) for x in args.num_warps] + ns_list = [int(x) for x in args.num_stages] + + gemms: list[tuple[str, callable]] = [] + if "gemm1" in args.gemms: + gemms.append(("gemm1", make_gemm1_fn)) + if "gemm2" in args.gemms: + gemms.append(("gemm2", make_gemm2_fn)) + + rows: list[dict] = [] + + for gname, maker in gemms: + print( + f"\n=== {gname} sweep " + f"(M={args.m}, N={args.n}, K={args.k}, E={args.e}, " + f"topk={args.topk}, group_size={args.group_size}) ===" + ) + results: list[tuple[dict, float]] = [] + for bm, bn, bk, gm, nw, ns in itertools.product( + bm_list, bn_list, bk_list, gm_list, nw_list, ns_list + ): + cfg = dict( + BLOCK_SIZE_M=bm, + BLOCK_SIZE_N=bn, + BLOCK_SIZE_K=bk, + GROUP_SIZE_M=gm, + num_warps=nw, + num_stages=ns, + ) + try: + fn = maker(inputs, cfg, args.group_size, args.e) + fn() # warmup + correctness sanity + torch.accelerator.synchronize() + t = time_us(fn) + except Exception as e: + print(f" SKIP {cfg}: {type(e).__name__}: {e}") + continue + results.append((cfg, t)) + rows.append({"gemm": gname, **cfg, "us": t}) + + results.sort(key=lambda x: x[1]) + ref_cfg, ref_t = results[0] + print( + f" {'rank':>4} {'BM':>4} {'BN':>4} {'BK':>4} {'GM':>4} " + f"{'nw':>3} {'ns':>3} {'us':>9} {'vs best':>8}" + ) + for i, (cfg, t) in enumerate(results): + mark = "*" if i == 0 else " " + print( + f" {i + 1:>4}{mark} {cfg['BLOCK_SIZE_M']:>4} " + f"{cfg['BLOCK_SIZE_N']:>4} {cfg['BLOCK_SIZE_K']:>4} " + f"{cfg['GROUP_SIZE_M']:>4} {cfg['num_warps']:>3} " + f"{cfg['num_stages']:>3} {t:>9.1f} {t / ref_t:>7.2f}x" + ) + + if args.csv: + with open(args.csv, "w", newline="") as f: + w = csv.DictWriter( + f, + fieldnames=[ + "gemm", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "us", + ], + ) + w.writeheader() + w.writerows(rows) + print(f"\nCSV: {args.csv}") + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + # Default = Qwen3-Omni-30B-A3B thinker shape. + p.add_argument("--m", type=int, default=2048, help="Prefill batch (rows)") + p.add_argument( + "--n", + type=int, + default=768, + help="moe_intermediate_size (single, not 2x). 768 for Qwen3-Omni, " + "512 for Qwen3.5-A3B.", + ) + p.add_argument("--k", type=int, default=2048, help="hidden_size") + p.add_argument( + "--e", + type=int, + default=128, + help="num_experts. 128 for Qwen3-Omni, 256 for Qwen3.5-A3B.", + ) + p.add_argument("--topk", type=int, default=8, help="experts per token") + p.add_argument( + "--group-size", + type=int, + default=32, + help="W4A16 group size. 32 for Qwen3-Omni, 128 for Qwen3.5-A3B.", + ) + p.add_argument( + "--gemms", + nargs="+", + default=["gemm1", "gemm2"], + choices=["gemm1", "gemm2"], + ) + p.add_argument("--block-m", nargs="+", default=[16, 32, 64, 128]) + p.add_argument("--block-n", nargs="+", default=[16, 32, 64, 128]) + p.add_argument( + "--block-k", + nargs="+", + default=[32, 64, 128], + help="Capped to group_size inside the wrapper", + ) + p.add_argument("--group-m", nargs="+", default=[1, 4, 8]) + p.add_argument("--num-warps", nargs="+", default=[2, 4, 8]) + p.add_argument("--num-stages", nargs="+", default=[1, 2]) + p.add_argument( + "--csv", type=str, default=None, help="Write all results to this CSV path" + ) + args = p.parse_args() + sweep(args) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/moe/test_hybrid_w4a16_moe.py b/tests/kernels/moe/test_hybrid_w4a16_moe.py index 77f2b8fe4352..4fdbc0741376 100644 --- a/tests/kernels/moe/test_hybrid_w4a16_moe.py +++ b/tests/kernels/moe/test_hybrid_w4a16_moe.py @@ -256,7 +256,8 @@ def test_hybrid_w4a16_moe_force_triton( hybrid_out, torch_output = _run_hybrid_moe( m, n, k, e, topk, group_size, force_triton=True ) - torch.testing.assert_close(hybrid_out, torch_output, atol=2e-2, rtol=0) + # gs<=64 _triton_config branch reorders fp16 reductions; needs 3e-2. + torch.testing.assert_close(hybrid_out, torch_output, atol=3e-2, rtol=0) @pytest.mark.skipif( diff --git a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py index 6f14f0a9e722..799db13521aa 100644 --- a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py +++ b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py @@ -131,16 +131,26 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # neutral (not measured in this PR). TRITON_BLOCK_SIZE_M = 32 - @staticmethod - def _select_block_size_m(num_tokens: int, topk: int, E: int) -> int: + # Alignment for the group_size <= SMALL_GROUP_SIZE_THRESHOLD branch + # of _triton_config. Must be lcm of the BLOCK_SIZE_M values that + # branch emits (128 for gemm1, 64 for gemm2). + TRITON_BLOCK_SIZE_M_SMALL_GS = 128 + + # group_size <= THRESHOLD uses the small-gs _triton_config branch; + # > THRESHOLD uses the default. Boundary between 32 and 128. + SMALL_GROUP_SIZE_THRESHOLD = 64 + + def _select_block_size_m(self, num_tokens: int, topk: int, E: int) -> int: """Select block size in the M dimension. - Decode (num_tokens <= MAX_SKINNY_BATCH_SIZE): use small block sizes - compatible with the wvSplitK_int4 HIP kernel (N=1..5). - Prefill (num_tokens > MAX_SKINNY_BATCH_SIZE): use the Triton kernel's - BLOCK_SIZE_M for efficient batched GEMM. + Decode (num_tokens <= MAX_SKINNY_BATCH_SIZE): small block sizes + compatible with the wvSplitK_int4 HIP kernel. + Prefill: TRITON_BLOCK_SIZE_M, or TRITON_BLOCK_SIZE_M_SMALL_GS + for small group_size. """ if num_tokens > HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE: + if self._group_size <= HybridW4A16MoEExperts.SMALL_GROUP_SIZE_THRESHOLD: + return HybridW4A16MoEExperts.TRITON_BLOCK_SIZE_M_SMALL_GS return HybridW4A16MoEExperts.TRITON_BLOCK_SIZE_M if num_tokens > 1: avg = num_tokens * topk / E @@ -213,6 +223,27 @@ def _triton_config( BLOCK_SIZE_K = self._group_size # = 128 for the Qwen3.5-A3B path assert BLOCK_SIZE_K % 8 == 0 + if self._group_size <= HybridW4A16MoEExperts.SMALL_GROUP_SIZE_THRESHOLD: + # gemm2 BM=64 < alignment=128; apply()'s _expert_ids_for + # repeat_interleaves expert_ids to compensate. + if K < 1024: + return dict( + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + num_warps=4, + num_stages=1, + ) + return dict( + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + num_warps=8, + num_stages=1, + ) + # Per-shape sweep on Strix Halo (gfx1151) at BLOCK_M=32 (the # current alignment), using benchmarks/kernels/ # sweep_int4g_moe_kernel.py + a per-shape direct call to From b7c9d59a2d5904ad41ed840790c01d7eef703c15 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 21 May 2026 14:42:08 -0600 Subject: [PATCH 2/3] [kernels] hybrid_w4a16_moe: allow BLOCK_K > group_size The hybrid Triton kernel's shuffle_w4a16 path loaded one scale per K-tile, which forced ``invoke_fused_moe_kernel_hybrid_triton`` to cap BLOCK_SIZE_K at group_size. At group_size=32 (Qwen3-Omni) this means 64 inner-loop iterations for K=2048 -- many short matmuls instead of a few large ones. Add a constexpr-gated multi-scale path to the kernel: when BLOCK_SIZE_K > group_size, load a per-K-row scale tensor [BLOCK_K, BLOCK_N] (mirrors what the non-shuffle wna16 path already does); otherwise keep the original [BLOCK_N] one-scale-per-tile path unchanged. The wrapper's cap is replaced with an assertion that BK either divides group_size or is a multiple of it. Net effect at the Qwen3-Omni shape on Strix Halo: BK=32 still wins the per-gemm sweep (register pressure caps BK gains; this kernel is already near peak at small BK on gfx1151), so the production _triton_config selection is unchanged. The lift unblocks future tuning at other (K, group_size) combinations. Verified: tests/kernels/moe/test_hybrid_w4a16_moe.py 70/70 pass. Signed-off-by: Matthias Gehre --- .../layers/fused_moe/fused_moe.py | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 73eff820d207..088d5593fd4d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -279,17 +279,29 @@ def fused_moe_kernel_gptq_awq( b_nk = (b_exp >> exl_shifts) & 0xF # [BLOCK_N, BLOCK_K] b = tl.trans(b_nk) # [BLOCK_K, BLOCK_N] - # Scales: [E, N, K//G] — load per-group scale for this K tile - g_idx = (k * BLOCK_SIZE_K) // group_size - b_scale_ptrs = ( - b_scale_ptr - + off_experts * stride_bse - + offs_bn * stride_bsn - + g_idx * stride_bsk - ) - b_scale = tl.load(b_scale_ptrs).to(tl.float32) - # Dequant: (nibble - 8) * scale - b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + # Scales: [E, N, K//G]. Two constexpr paths — Triton does not + # broadcast a degenerate [BLOCK_K, BLOCK_N] load back to a + # scalar reliably, so BLOCK_K <= group_size keeps its own + # [BLOCK_N] load (one scale per tile) for full speed. + if group_size >= BLOCK_SIZE_K: + g_idx = (k * BLOCK_SIZE_K) // group_size + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn * stride_bsn + + g_idx * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + else: + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) # [BLOCK_K, BLOCK_N] + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) else: b = tl.load(b_ptrs) if use_int4_w4a16: @@ -817,12 +829,15 @@ def invoke_fused_moe_kernel_hybrid_triton( config = config.copy() config["SPLIT_K"] = 1 BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") - # BLOCK_K must be multiple of 8 for ExLlama shuffle interleave + # BLOCK_K must be multiple of 8 for ExLlama shuffle interleave. assert BLOCK_SIZE_K % 8 == 0 - # BLOCK_K must not exceed group_size (one scale per K-tile) - BLOCK_SIZE_K = min(BLOCK_SIZE_K, group_size) - assert BLOCK_SIZE_K % 8 == 0, ( - f"group_size {group_size} must be a multiple of 8 for shuffle kernel" + # BLOCK_K must tile cleanly across the per-group scale layout: + # either BK <= group_size (multiple tiles share one scale) or + # BK >= group_size with BK % group_size == 0 (each tile spans an + # integer number of groups; scales loaded per K-row inside the kernel). + assert group_size >= BLOCK_SIZE_K or BLOCK_SIZE_K % group_size == 0, ( + f"BLOCK_SIZE_K ({BLOCK_SIZE_K}) must be <= group_size ({group_size}) " + f"or a multiple of group_size" ) with record_function_or_nullcontext( From 84660b1119506e96327a628b2591a725d72dbe43 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 26 May 2026 03:52:45 -0600 Subject: [PATCH 3/3] [kernels] hybrid_w4a16_moe: add experimental grouped-GEMM prefill path The production path uses moe_align_block_size + sorted_token_ids with the kernel doing a virtual gather of A. This pads num_slots up to M*top_k + E*(BM-1), launching ~50% padded blocks that early-exit but still consume kernel-launch slots. Add an alternative apply path that uses the existing moe_permute / moe_unpermute ops (already supported by cutlass_moe + exllama paths) to lay activations out in expert-contiguous order, and a new Triton kernel (fused_moe_kernel_hybrid_w4a16_grouped) that reads them contiguously, indexed by a per-block (expert_id, m_start, m_count) table. No padding, no expert_ids[block]==-1 sentinel, no virtual gather. End-to-end TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit (--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1 on Strix Halo gfx1151, back-to-back same-session A/B): grouped OFF (production path): 1867 ms grouped ON: 1766 ms (-5.5%) Correctness checked vs the production path at 4 shapes (Qwen3-Omni m=128 and m=2048 with group_size=32, Qwen3.5-A3B m=2048 with group_size=128, and a tiny edge case m=16 n=k=256 group_size=32 with force_triton): grouped vs padded max abs diff = 0.0 in every case (bit-identical). Changes: - Gated behind VLLM_HYBRID_W4A16_GROUPED=1 (off by default) for the initial landing; only triggers for prefill (num_tokens > 5) and only when expert_map is None and apply_router_weight_on_input is False. Other config combinations fall through to the existing path. - The new grouped kernel reuses the constexpr-gated multi-scale-per- tile path from fused_moe_kernel_gptq_awq (so BLOCK_K > group_size is supported for free if it ever becomes useful). - Linear pid mapping rather than the GROUP_SIZE_M-swizzled mapping from the original kernel: consecutive blocks in the grouped layout belong to different experts so the B-reuse heuristic does not apply. GROUP_SIZE_M is accepted in the config but ignored. Signed-off-by: Matthias Gehre --- .../layers/fused_moe/hybrid_w4a16_grouped.py | 384 ++++++++++++++++++ .../layers/fused_moe/hybrid_w4a16_moe.py | 36 ++ 2 files changed, 420 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py diff --git a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py new file mode 100644 index 000000000000..7bc9faafdfb3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Grouped-GEMM Triton kernel for shuffle-packed INT4 W4A16 MoE. + +Consumes pre-permuted activations + per-block (expert_id, m_start, +m_count) routing, so the kernel reads contiguous rows of A and never +sees padding blocks. Built to compare against the moe_align_block_size ++ sorted_token_ids path used by ``HybridW4A16MoEExperts.apply``. +""" + +from __future__ import annotations + +import torch + +from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import ( # noqa: F401 (re-export for callers) + direct_register_custom_op, +) + + +@triton.jit +def fused_moe_kernel_hybrid_w4a16_grouped( + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + block_expert_ids_ptr, + block_m_starts_ptr, + block_m_counts_ptr, + N: tl.constexpr, + K: tl.constexpr, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + compute_type: tl.constexpr, +): + """Grouped-GEMM shuffle_w4a16 kernel. + + Parameters + ---------- + a_ptr : pointer to permuted activations [num_routed_tokens, K] + b_ptr : pointer to weights [E, N, K//8] int32 (ExLlama shuffle packed) + c_ptr : pointer to output [num_routed_tokens, N] + b_scale_ptr : pointer to scales [E, N, K//G] (fp16/bf16) + block_expert_ids_ptr : [num_blocks] int32 — expert per block + block_m_starts_ptr : [num_blocks] int32 — start row in A for each block + block_m_counts_ptr : [num_blocks] int32 — valid rows in each block (<= BM) + """ + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # Linear pid mapping. Consecutive blocks belong to different experts + # in the grouped layout, so the original swizzled GROUP_SIZE_M-style + # B-reuse heuristic does not apply here. GROUP_SIZE_M is accepted but + # ignored. + _ = GROUP_SIZE_M + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + expert_id = tl.load(block_expert_ids_ptr + pid_m).to(tl.int64) + m_start = tl.load(block_m_starts_ptr + pid_m).to(tl.int64) + m_count = tl.load(block_m_counts_ptr + pid_m).to(tl.int32) + + offs_m_local = tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_m_local < m_count + offs_token = m_start + offs_m_local.to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_token[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Shuffle-packed INT4 weight setup (mirrors fused_moe_kernel_gptq_awq). + offs_k8 = tl.arange(0, BLOCK_SIZE_K // 8) + b_packed_ptrs = ( + b_ptr + + expert_id * stride_be + + offs_bn[:, None] * stride_bn + + offs_k8[None, :] * stride_bk + ) + _exl_shifts_row = (tl.arange(0, 8) // 2) * 4 + (tl.arange(0, 8) % 2) * 16 + _exl_shifts_1d = tl.reshape( + tl.broadcast_to(_exl_shifts_row[None, :], (BLOCK_SIZE_K // 8, 8)), + (BLOCK_SIZE_K,), + ) + exl_shifts = tl.broadcast_to(_exl_shifts_1d[None, :], (BLOCK_SIZE_N, BLOCK_SIZE_K)) + b_zp_num = 8 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=row_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + + b_packed = tl.load(b_packed_ptrs) + b_exp = tl.interleave(b_packed, b_packed) + b_exp = tl.interleave(b_exp, b_exp) + b_exp = tl.interleave(b_exp, b_exp) + b_nk = (b_exp >> exl_shifts) & 0xF + b = tl.trans(b_nk) + + if group_size >= BLOCK_SIZE_K: + g_idx = (k * BLOCK_SIZE_K) // group_size + b_scale_ptrs = ( + b_scale_ptr + + expert_id * stride_bse + + offs_bn * stride_bsn + + g_idx * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + else: + b_scale_ptrs = ( + b_scale_ptr + + expert_id * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + + accumulator = tl.dot(a, b, acc=accumulator) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_packed_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + + accumulator = accumulator.to(compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = row_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def build_block_table( + expert_first_token_offset: torch.Tensor, + block_size_m: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Translate per-expert token offsets into a per-block routing table. + + Parameters + ---------- + expert_first_token_offset : [E+1] int64, cumulative tokens per expert + (output of moe_permute). + block_size_m : kernel BLOCK_SIZE_M. + + Returns + ------- + block_expert_ids : [num_blocks] int32 — expert per block + block_m_starts : [num_blocks] int32 — start row in permuted A + block_m_counts : [num_blocks] int32 — valid rows in this block (<= BM) + """ + device = expert_first_token_offset.device + counts64 = expert_first_token_offset[1:] - expert_first_token_offset[:-1] # [E] + blocks_per_expert = ((counts64 + block_size_m - 1) // block_size_m).to(torch.int32) + + cum_blocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + blocks_per_expert.cumsum(0).to(torch.int32), + ] + ) # [E+1] + total_blocks = int(cum_blocks[-1].item()) + + if total_blocks == 0: + empty = torch.empty(0, dtype=torch.int32, device=device) + return empty, empty, empty + + block_indices = torch.arange(total_blocks, dtype=torch.int32, device=device) + # which expert does each block belong to + block_expert_ids = torch.searchsorted(cum_blocks[1:], block_indices, right=True).to( + torch.int32 + ) + # index of the block within its expert (0..blocks_per_expert[e]-1) + block_in_expert = block_indices - cum_blocks[block_expert_ids.long()] + # start row in permuted activations + block_m_starts = ( + expert_first_token_offset[block_expert_ids.long()].to(torch.int32) + + block_in_expert * block_size_m + ) + # how many valid rows in this block (last block of each expert is partial) + block_m_counts = torch.minimum( + torch.full((total_blocks,), block_size_m, dtype=torch.int32, device=device), + counts64[block_expert_ids.long()].to(torch.int32) + - block_in_expert * block_size_m, + ) + return block_expert_ids, block_m_starts, block_m_counts + + +def apply_hybrid_w4a16_grouped( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int, + activation, + group_size: int, + gemm1_config: dict, + gemm2_config: dict, +) -> None: + """Grouped prefill path: moe_permute -> grouped kernel x2 -> moe_unpermute. + + No moe_align_block_size, no padding, no virtual gather. Activations + are physically permuted into expert-contiguous order; per-block routing + (expert_id, m_start, m_count) is built from expert_first_token_offset + and consumed by the grouped Triton kernel. + """ + from vllm.model_executor.layers.fused_moe.activation import apply_moe_activation + from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, + moe_unpermute, + ) + + M = hidden_states.size(0) + K = hidden_states.size(1) + E = w1.size(0) + N_w1 = w1.size(1) # 2 * intermediate + top_k = topk_ids.size(1) + P = M * top_k + if global_num_experts == -1: + global_num_experts = E + + # Permute hidden_states into expert-contiguous order. + permuted_hidden, _, e_offsets, inv_perm, _ = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=global_num_experts, + ) + + # GEMM 1: permuted [P, K] -> gemm1_out [P, N_w1] + bt1 = build_block_table(e_offsets, gemm1_config["BLOCK_SIZE_M"]) + gemm1_out = torch.empty( + P, + N_w1, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + compute_type = tl.float16 if hidden_states.dtype == torch.float16 else tl.bfloat16 + invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A=permuted_hidden, + B=w1, + C=gemm1_out, + B_scale=w1_scale, + block_expert_ids=bt1[0], + block_m_starts=bt1[1], + block_m_counts=bt1[2], + config=gemm1_config, + compute_type=compute_type, + group_size=group_size, + ) + + # Activation: in-place along N (halves N_w1 -> intermediate). + from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEExpertsModular, + ) + + activation_out_dim = FusedMoEExpertsModular.adjust_N_for_activation( + N_w1, + activation, + ) + act_out = torch.empty( + P, + activation_out_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + apply_moe_activation(activation, act_out, gemm1_out) + + # GEMM 2: act_out [P, intermediate] -> gemm2_out [P, K] + bt2 = build_block_table(e_offsets, gemm2_config["BLOCK_SIZE_M"]) + gemm2_out = torch.empty( + P, + K, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A=act_out, + B=w2, + C=gemm2_out, + B_scale=w2_scale, + block_expert_ids=bt2[0], + block_m_starts=bt2[1], + block_m_counts=bt2[2], + config=gemm2_config, + compute_type=compute_type, + group_size=group_size, + ) + + # Unpermute + topk-weight fold + reduce. + moe_unpermute( + out=output, + permuted_hidden_states=gemm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) + + +def invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + B_scale: torch.Tensor, + block_expert_ids: torch.Tensor, + block_m_starts: torch.Tensor, + block_m_counts: torch.Tensor, + config: dict, + compute_type, + group_size: int, +) -> None: + assert B.dtype == torch.int32 + assert B_scale is not None and B_scale.ndim == 3 + + K = A.size(1) + N = B.size(1) + num_blocks = block_expert_ids.size(0) + + cfg = config.copy() + BLOCK_SIZE_M = cfg.pop("BLOCK_SIZE_M") + BLOCK_SIZE_N = cfg.pop("BLOCK_SIZE_N") + BLOCK_SIZE_K = cfg.pop("BLOCK_SIZE_K") + GROUP_SIZE_M = cfg.pop("GROUP_SIZE_M") + num_warps = cfg.pop("num_warps") + num_stages = cfg.pop("num_stages") + assert not cfg, f"unexpected config keys: {list(cfg)}" + assert BLOCK_SIZE_K % 8 == 0 + assert group_size >= BLOCK_SIZE_K or BLOCK_SIZE_K % group_size == 0 + + grid = (num_blocks * triton.cdiv(N, BLOCK_SIZE_N),) + fused_moe_kernel_hybrid_w4a16_grouped[grid]( + A, + B, + C, + B_scale, + block_expert_ids, + block_m_starts, + block_m_counts, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + block_k_diviable=K % BLOCK_SIZE_K == 0, + group_size=group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + compute_type=compute_type, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py index 799db13521aa..70953112a96f 100644 --- a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py +++ b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py @@ -10,6 +10,8 @@ CUDA-graph compatible. """ +import os + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -307,6 +309,40 @@ def apply( hidden_states, w1, w2, topk_ids ) + # Experimental grouped-GEMM prefill path. Gated off by default. + # Skips moe_align_block_size + virtual gather in favour of a + # physical moe_permute + per-block routing table consumed by a + # dedicated grouped Triton kernel (no padding waste). + if ( + os.environ.get("VLLM_HYBRID_W4A16_GROUPED", "0") == "1" + and num_tokens > self.MAX_SKINNY_BATCH_SIZE + and expert_map is None + and not apply_router_weight_on_input + ): + from vllm.model_executor.layers.fused_moe.hybrid_w4a16_grouped import ( + apply_hybrid_w4a16_grouped, + ) + + activation_out_dim = self.adjust_N_for_activation(N, activation) + apply_hybrid_w4a16_grouped( + output=output, + hidden_states=hidden_states, + w1=w1, + w1_scale=self.w1_scale, + w2=w2, + w2_scale=self.w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=global_num_experts, + activation=activation, + group_size=self._group_size, + gemm1_config=self._triton_config(K, num_tokens * top_k_num), + gemm2_config=self._triton_config( + activation_out_dim, num_tokens * top_k_num + ), + ) + return + if global_num_experts == -1: global_num_experts = E