diff --git a/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py b/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py new file mode 100644 index 000000000000..5cbb7deebab5 --- /dev/null +++ b/benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Kernel-only sweep for the hybrid_w4a16 MoE prefill Triton kernel. + +Calls ``invoke_fused_moe_kernel_hybrid_triton`` directly (skipping the +HybridW4A16MoEExperts.apply host setup so kernel time is not drowned +out, unlike benchmarks/kernels/bench_hybrid_w4a16_moe.py). Sweeps the +full Triton config space and ranks results per gemm. + +Default shapes target Qwen3-Omni-30B-A3B (E=128, group_size=32, +hidden=2048, moe_intermediate=768) so the tune for this model can be +re-derived from scratch. Override with --n / --k / --e / --group-size +to retune for other shapes (e.g. --n 512 --k 2048 --e 256 +--group-size 128 for Qwen3.5-A3B). + +Usage: + python benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py + python benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py \\ + --m 2048 --n 512 --k 2048 --e 256 --topk 8 --group-size 128 \\ + --csv /scratch/mgehre/tmp/sweep_qwen35.csv +""" + +from __future__ import annotations + +import argparse +import csv +import itertools +import sys +from pathlib import Path + +import torch + +REPO = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO)) + +from tests.kernels.moe.test_hybrid_w4a16_moe import ( # noqa: E402 + _make_hybrid_moe_weights, +) +from vllm.model_executor.layers.fused_moe import fused_topk # noqa: E402 +from vllm.model_executor.layers.fused_moe.fused_moe import ( # noqa: E402 + invoke_fused_moe_kernel_hybrid_triton, +) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( # noqa: E402 + moe_align_block_size, +) +from vllm.triton_utils import tl, triton # noqa: E402 + + +def build_inputs( + M: int, + N_inter: int, + K_hidden: int, + E: int, + topk: int, + group_size: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", +): + """Build weights, hidden, and topk_ids once for the whole sweep.""" + w1, w1_s, _ = _make_hybrid_moe_weights(E, K_hidden, 2 * N_inter, group_size, device) + w2, w2_s, _ = _make_hybrid_moe_weights(E, N_inter, K_hidden, group_size, device) + hidden = torch.randn(M, K_hidden, device=device, dtype=dtype) / 10.0 + scores = torch.randn(M, E, device=device, dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden, scores, topk, False) + return dict( + w1=w1, + w1_s=w1_s, + w2=w2, + w2_s=w2_s, + hidden=hidden, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + + +def time_us(fn, warmup_ms: int = 25, rep_ms: int = 80) -> float: + return ( + triton.testing.do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="median") + * 1000.0 + ) + + +def make_gemm1_fn(inputs, cfg, group_size, E): + """Closure that runs only gemm1 (hidden -> 2*N_inter).""" + block_m = cfg["BLOCK_SIZE_M"] + sorted_ids, expert_ids, npp = moe_align_block_size( + inputs["topk_ids"], block_m, E, None, ignore_invalid_experts=True + ) + num_slots = sorted_ids.size(0) + N_w = inputs["w1"].size(1) # 2*N_inter + out = torch.empty( + num_slots, N_w, device=inputs["hidden"].device, dtype=inputs["hidden"].dtype + ) + compute_type = ( + tl.float16 if inputs["hidden"].dtype == torch.float16 else tl.bfloat16 + ) + + def run(): + invoke_fused_moe_kernel_hybrid_triton( + A=inputs["hidden"], + B=inputs["w1"], + C=out, + B_scale=inputs["w1_s"], + topk_weights=None, + sorted_token_ids=sorted_ids, + expert_ids=expert_ids, + num_tokens_post_padded=npp, + mul_routed_weight=False, + top_k=inputs["topk_ids"].size(1), + config=cfg, + compute_type=compute_type, + group_size=group_size, + ) + + return run + + +def make_gemm2_fn(inputs, cfg, group_size, E): + """Closure that runs only gemm2 (N_inter -> hidden). + + Mimics apply()'s second call: A is in slot-space (one row per slot + with the post-activation activations), B is w2, top_k is 1 so the + kernel reads A[slot] directly. + """ + block_m = cfg["BLOCK_SIZE_M"] + sorted_ids, expert_ids, npp = moe_align_block_size( + inputs["topk_ids"], block_m, E, None, ignore_invalid_experts=True + ) + num_slots = sorted_ids.size(0) + # w2: [E, N=K_hidden, K_in=N_inter//8 (int32, holds N_inter//8 int4 elems each)] + K_in = inputs["w2"].size(2) * 8 + K_hidden = inputs["w2"].size(1) + act = ( + torch.randn( + num_slots, + K_in, + device=inputs["hidden"].device, + dtype=inputs["hidden"].dtype, + ) + / 10.0 + ) + out = torch.empty( + num_slots, + K_hidden, + device=inputs["hidden"].device, + dtype=inputs["hidden"].dtype, + ) + compute_type = ( + tl.float16 if inputs["hidden"].dtype == torch.float16 else tl.bfloat16 + ) + + def run(): + invoke_fused_moe_kernel_hybrid_triton( + A=act, + B=inputs["w2"], + C=out, + B_scale=inputs["w2_s"], + topk_weights=None, + sorted_token_ids=sorted_ids, + expert_ids=expert_ids, + num_tokens_post_padded=npp, + mul_routed_weight=False, + top_k=1, + config=cfg, + compute_type=compute_type, + group_size=group_size, + ) + + return run + + +def sweep(args): + inputs = build_inputs(args.m, args.n, args.k, args.e, args.topk, args.group_size) + torch.accelerator.synchronize() + + # Each (BM, BN, BK, GM, nw, ns) candidate. BK is capped to + # group_size inside the wrapper, so listing values > group_size is + # equivalent to BK=group_size. + bm_list = [int(x) for x in args.block_m] + bn_list = [int(x) for x in args.block_n] + bk_list = [int(x) for x in args.block_k] + gm_list = [int(x) for x in args.group_m] + nw_list = [int(x) for x in args.num_warps] + ns_list = [int(x) for x in args.num_stages] + + gemms: list[tuple[str, callable]] = [] + if "gemm1" in args.gemms: + gemms.append(("gemm1", make_gemm1_fn)) + if "gemm2" in args.gemms: + gemms.append(("gemm2", make_gemm2_fn)) + + rows: list[dict] = [] + + for gname, maker in gemms: + print( + f"\n=== {gname} sweep " + f"(M={args.m}, N={args.n}, K={args.k}, E={args.e}, " + f"topk={args.topk}, group_size={args.group_size}) ===" + ) + results: list[tuple[dict, float]] = [] + for bm, bn, bk, gm, nw, ns in itertools.product( + bm_list, bn_list, bk_list, gm_list, nw_list, ns_list + ): + cfg = dict( + BLOCK_SIZE_M=bm, + BLOCK_SIZE_N=bn, + BLOCK_SIZE_K=bk, + GROUP_SIZE_M=gm, + num_warps=nw, + num_stages=ns, + ) + try: + fn = maker(inputs, cfg, args.group_size, args.e) + fn() # warmup + correctness sanity + torch.accelerator.synchronize() + t = time_us(fn) + except Exception as e: + print(f" SKIP {cfg}: {type(e).__name__}: {e}") + continue + results.append((cfg, t)) + rows.append({"gemm": gname, **cfg, "us": t}) + + results.sort(key=lambda x: x[1]) + ref_cfg, ref_t = results[0] + print( + f" {'rank':>4} {'BM':>4} {'BN':>4} {'BK':>4} {'GM':>4} " + f"{'nw':>3} {'ns':>3} {'us':>9} {'vs best':>8}" + ) + for i, (cfg, t) in enumerate(results): + mark = "*" if i == 0 else " " + print( + f" {i + 1:>4}{mark} {cfg['BLOCK_SIZE_M']:>4} " + f"{cfg['BLOCK_SIZE_N']:>4} {cfg['BLOCK_SIZE_K']:>4} " + f"{cfg['GROUP_SIZE_M']:>4} {cfg['num_warps']:>3} " + f"{cfg['num_stages']:>3} {t:>9.1f} {t / ref_t:>7.2f}x" + ) + + if args.csv: + with open(args.csv, "w", newline="") as f: + w = csv.DictWriter( + f, + fieldnames=[ + "gemm", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "num_warps", + "num_stages", + "us", + ], + ) + w.writeheader() + w.writerows(rows) + print(f"\nCSV: {args.csv}") + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + # Default = Qwen3-Omni-30B-A3B thinker shape. + p.add_argument("--m", type=int, default=2048, help="Prefill batch (rows)") + p.add_argument( + "--n", + type=int, + default=768, + help="moe_intermediate_size (single, not 2x). 768 for Qwen3-Omni, " + "512 for Qwen3.5-A3B.", + ) + p.add_argument("--k", type=int, default=2048, help="hidden_size") + p.add_argument( + "--e", + type=int, + default=128, + help="num_experts. 128 for Qwen3-Omni, 256 for Qwen3.5-A3B.", + ) + p.add_argument("--topk", type=int, default=8, help="experts per token") + p.add_argument( + "--group-size", + type=int, + default=32, + help="W4A16 group size. 32 for Qwen3-Omni, 128 for Qwen3.5-A3B.", + ) + p.add_argument( + "--gemms", + nargs="+", + default=["gemm1", "gemm2"], + choices=["gemm1", "gemm2"], + ) + p.add_argument("--block-m", nargs="+", default=[16, 32, 64, 128]) + p.add_argument("--block-n", nargs="+", default=[16, 32, 64, 128]) + p.add_argument( + "--block-k", + nargs="+", + default=[32, 64, 128], + help="Capped to group_size inside the wrapper", + ) + p.add_argument("--group-m", nargs="+", default=[1, 4, 8]) + p.add_argument("--num-warps", nargs="+", default=[2, 4, 8]) + p.add_argument("--num-stages", nargs="+", default=[1, 2]) + p.add_argument( + "--csv", type=str, default=None, help="Write all results to this CSV path" + ) + args = p.parse_args() + sweep(args) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/moe/test_hybrid_w4a16_moe.py b/tests/kernels/moe/test_hybrid_w4a16_moe.py index 77f2b8fe4352..4fdbc0741376 100644 --- a/tests/kernels/moe/test_hybrid_w4a16_moe.py +++ b/tests/kernels/moe/test_hybrid_w4a16_moe.py @@ -256,7 +256,8 @@ def test_hybrid_w4a16_moe_force_triton( hybrid_out, torch_output = _run_hybrid_moe( m, n, k, e, topk, group_size, force_triton=True ) - torch.testing.assert_close(hybrid_out, torch_output, atol=2e-2, rtol=0) + # gs<=64 _triton_config branch reorders fp16 reductions; needs 3e-2. + torch.testing.assert_close(hybrid_out, torch_output, atol=3e-2, rtol=0) @pytest.mark.skipif( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 73eff820d207..088d5593fd4d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -279,17 +279,29 @@ def fused_moe_kernel_gptq_awq( b_nk = (b_exp >> exl_shifts) & 0xF # [BLOCK_N, BLOCK_K] b = tl.trans(b_nk) # [BLOCK_K, BLOCK_N] - # Scales: [E, N, K//G] — load per-group scale for this K tile - g_idx = (k * BLOCK_SIZE_K) // group_size - b_scale_ptrs = ( - b_scale_ptr - + off_experts * stride_bse - + offs_bn * stride_bsn - + g_idx * stride_bsk - ) - b_scale = tl.load(b_scale_ptrs).to(tl.float32) - # Dequant: (nibble - 8) * scale - b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + # Scales: [E, N, K//G]. Two constexpr paths — Triton does not + # broadcast a degenerate [BLOCK_K, BLOCK_N] load back to a + # scalar reliably, so BLOCK_K <= group_size keeps its own + # [BLOCK_N] load (one scale per tile) for full speed. + if group_size >= BLOCK_SIZE_K: + g_idx = (k * BLOCK_SIZE_K) // group_size + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn * stride_bsn + + g_idx * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + else: + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) # [BLOCK_K, BLOCK_N] + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) else: b = tl.load(b_ptrs) if use_int4_w4a16: @@ -817,12 +829,15 @@ def invoke_fused_moe_kernel_hybrid_triton( config = config.copy() config["SPLIT_K"] = 1 BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") - # BLOCK_K must be multiple of 8 for ExLlama shuffle interleave + # BLOCK_K must be multiple of 8 for ExLlama shuffle interleave. assert BLOCK_SIZE_K % 8 == 0 - # BLOCK_K must not exceed group_size (one scale per K-tile) - BLOCK_SIZE_K = min(BLOCK_SIZE_K, group_size) - assert BLOCK_SIZE_K % 8 == 0, ( - f"group_size {group_size} must be a multiple of 8 for shuffle kernel" + # BLOCK_K must tile cleanly across the per-group scale layout: + # either BK <= group_size (multiple tiles share one scale) or + # BK >= group_size with BK % group_size == 0 (each tile spans an + # integer number of groups; scales loaded per K-row inside the kernel). + assert group_size >= BLOCK_SIZE_K or BLOCK_SIZE_K % group_size == 0, ( + f"BLOCK_SIZE_K ({BLOCK_SIZE_K}) must be <= group_size ({group_size}) " + f"or a multiple of group_size" ) with record_function_or_nullcontext( diff --git a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py new file mode 100644 index 000000000000..7bc9faafdfb3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_grouped.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Grouped-GEMM Triton kernel for shuffle-packed INT4 W4A16 MoE. + +Consumes pre-permuted activations + per-block (expert_id, m_start, +m_count) routing, so the kernel reads contiguous rows of A and never +sees padding blocks. Built to compare against the moe_align_block_size ++ sorted_token_ids path used by ``HybridW4A16MoEExperts.apply``. +""" + +from __future__ import annotations + +import torch + +from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import ( # noqa: F401 (re-export for callers) + direct_register_custom_op, +) + + +@triton.jit +def fused_moe_kernel_hybrid_w4a16_grouped( + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + block_expert_ids_ptr, + block_m_starts_ptr, + block_m_counts_ptr, + N: tl.constexpr, + K: tl.constexpr, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + compute_type: tl.constexpr, +): + """Grouped-GEMM shuffle_w4a16 kernel. + + Parameters + ---------- + a_ptr : pointer to permuted activations [num_routed_tokens, K] + b_ptr : pointer to weights [E, N, K//8] int32 (ExLlama shuffle packed) + c_ptr : pointer to output [num_routed_tokens, N] + b_scale_ptr : pointer to scales [E, N, K//G] (fp16/bf16) + block_expert_ids_ptr : [num_blocks] int32 — expert per block + block_m_starts_ptr : [num_blocks] int32 — start row in A for each block + block_m_counts_ptr : [num_blocks] int32 — valid rows in each block (<= BM) + """ + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # Linear pid mapping. Consecutive blocks belong to different experts + # in the grouped layout, so the original swizzled GROUP_SIZE_M-style + # B-reuse heuristic does not apply here. GROUP_SIZE_M is accepted but + # ignored. + _ = GROUP_SIZE_M + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + expert_id = tl.load(block_expert_ids_ptr + pid_m).to(tl.int64) + m_start = tl.load(block_m_starts_ptr + pid_m).to(tl.int64) + m_count = tl.load(block_m_counts_ptr + pid_m).to(tl.int32) + + offs_m_local = tl.arange(0, BLOCK_SIZE_M) + row_mask = offs_m_local < m_count + offs_token = m_start + offs_m_local.to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_token[:, None] * stride_am + offs_k[None, :] * stride_ak) + + # Shuffle-packed INT4 weight setup (mirrors fused_moe_kernel_gptq_awq). + offs_k8 = tl.arange(0, BLOCK_SIZE_K // 8) + b_packed_ptrs = ( + b_ptr + + expert_id * stride_be + + offs_bn[:, None] * stride_bn + + offs_k8[None, :] * stride_bk + ) + _exl_shifts_row = (tl.arange(0, 8) // 2) * 4 + (tl.arange(0, 8) % 2) * 16 + _exl_shifts_1d = tl.reshape( + tl.broadcast_to(_exl_shifts_row[None, :], (BLOCK_SIZE_K // 8, 8)), + (BLOCK_SIZE_K,), + ) + exl_shifts = tl.broadcast_to(_exl_shifts_1d[None, :], (BLOCK_SIZE_N, BLOCK_SIZE_K)) + b_zp_num = 8 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=row_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + + b_packed = tl.load(b_packed_ptrs) + b_exp = tl.interleave(b_packed, b_packed) + b_exp = tl.interleave(b_exp, b_exp) + b_exp = tl.interleave(b_exp, b_exp) + b_nk = (b_exp >> exl_shifts) & 0xF + b = tl.trans(b_nk) + + if group_size >= BLOCK_SIZE_K: + g_idx = (k * BLOCK_SIZE_K) // group_size + b_scale_ptrs = ( + b_scale_ptr + + expert_id * stride_bse + + offs_bn * stride_bsn + + g_idx * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale[None, :]).to(compute_type) + else: + b_scale_ptrs = ( + b_scale_ptr + + expert_id * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other).to(tl.float32) + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + + accumulator = tl.dot(a, b, acc=accumulator) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_packed_ptrs += (BLOCK_SIZE_K // 8) * stride_bk + + accumulator = accumulator.to(compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = row_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def build_block_table( + expert_first_token_offset: torch.Tensor, + block_size_m: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Translate per-expert token offsets into a per-block routing table. + + Parameters + ---------- + expert_first_token_offset : [E+1] int64, cumulative tokens per expert + (output of moe_permute). + block_size_m : kernel BLOCK_SIZE_M. + + Returns + ------- + block_expert_ids : [num_blocks] int32 — expert per block + block_m_starts : [num_blocks] int32 — start row in permuted A + block_m_counts : [num_blocks] int32 — valid rows in this block (<= BM) + """ + device = expert_first_token_offset.device + counts64 = expert_first_token_offset[1:] - expert_first_token_offset[:-1] # [E] + blocks_per_expert = ((counts64 + block_size_m - 1) // block_size_m).to(torch.int32) + + cum_blocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + blocks_per_expert.cumsum(0).to(torch.int32), + ] + ) # [E+1] + total_blocks = int(cum_blocks[-1].item()) + + if total_blocks == 0: + empty = torch.empty(0, dtype=torch.int32, device=device) + return empty, empty, empty + + block_indices = torch.arange(total_blocks, dtype=torch.int32, device=device) + # which expert does each block belong to + block_expert_ids = torch.searchsorted(cum_blocks[1:], block_indices, right=True).to( + torch.int32 + ) + # index of the block within its expert (0..blocks_per_expert[e]-1) + block_in_expert = block_indices - cum_blocks[block_expert_ids.long()] + # start row in permuted activations + block_m_starts = ( + expert_first_token_offset[block_expert_ids.long()].to(torch.int32) + + block_in_expert * block_size_m + ) + # how many valid rows in this block (last block of each expert is partial) + block_m_counts = torch.minimum( + torch.full((total_blocks,), block_size_m, dtype=torch.int32, device=device), + counts64[block_expert_ids.long()].to(torch.int32) + - block_in_expert * block_size_m, + ) + return block_expert_ids, block_m_starts, block_m_counts + + +def apply_hybrid_w4a16_grouped( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int, + activation, + group_size: int, + gemm1_config: dict, + gemm2_config: dict, +) -> None: + """Grouped prefill path: moe_permute -> grouped kernel x2 -> moe_unpermute. + + No moe_align_block_size, no padding, no virtual gather. Activations + are physically permuted into expert-contiguous order; per-block routing + (expert_id, m_start, m_count) is built from expert_first_token_offset + and consumed by the grouped Triton kernel. + """ + from vllm.model_executor.layers.fused_moe.activation import apply_moe_activation + from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, + moe_unpermute, + ) + + M = hidden_states.size(0) + K = hidden_states.size(1) + E = w1.size(0) + N_w1 = w1.size(1) # 2 * intermediate + top_k = topk_ids.size(1) + P = M * top_k + if global_num_experts == -1: + global_num_experts = E + + # Permute hidden_states into expert-contiguous order. + permuted_hidden, _, e_offsets, inv_perm, _ = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=global_num_experts, + ) + + # GEMM 1: permuted [P, K] -> gemm1_out [P, N_w1] + bt1 = build_block_table(e_offsets, gemm1_config["BLOCK_SIZE_M"]) + gemm1_out = torch.empty( + P, + N_w1, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + compute_type = tl.float16 if hidden_states.dtype == torch.float16 else tl.bfloat16 + invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A=permuted_hidden, + B=w1, + C=gemm1_out, + B_scale=w1_scale, + block_expert_ids=bt1[0], + block_m_starts=bt1[1], + block_m_counts=bt1[2], + config=gemm1_config, + compute_type=compute_type, + group_size=group_size, + ) + + # Activation: in-place along N (halves N_w1 -> intermediate). + from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEExpertsModular, + ) + + activation_out_dim = FusedMoEExpertsModular.adjust_N_for_activation( + N_w1, + activation, + ) + act_out = torch.empty( + P, + activation_out_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + apply_moe_activation(activation, act_out, gemm1_out) + + # GEMM 2: act_out [P, intermediate] -> gemm2_out [P, K] + bt2 = build_block_table(e_offsets, gemm2_config["BLOCK_SIZE_M"]) + gemm2_out = torch.empty( + P, + K, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A=act_out, + B=w2, + C=gemm2_out, + B_scale=w2_scale, + block_expert_ids=bt2[0], + block_m_starts=bt2[1], + block_m_counts=bt2[2], + config=gemm2_config, + compute_type=compute_type, + group_size=group_size, + ) + + # Unpermute + topk-weight fold + reduce. + moe_unpermute( + out=output, + permuted_hidden_states=gemm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) + + +def invoke_fused_moe_kernel_hybrid_w4a16_grouped( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + B_scale: torch.Tensor, + block_expert_ids: torch.Tensor, + block_m_starts: torch.Tensor, + block_m_counts: torch.Tensor, + config: dict, + compute_type, + group_size: int, +) -> None: + assert B.dtype == torch.int32 + assert B_scale is not None and B_scale.ndim == 3 + + K = A.size(1) + N = B.size(1) + num_blocks = block_expert_ids.size(0) + + cfg = config.copy() + BLOCK_SIZE_M = cfg.pop("BLOCK_SIZE_M") + BLOCK_SIZE_N = cfg.pop("BLOCK_SIZE_N") + BLOCK_SIZE_K = cfg.pop("BLOCK_SIZE_K") + GROUP_SIZE_M = cfg.pop("GROUP_SIZE_M") + num_warps = cfg.pop("num_warps") + num_stages = cfg.pop("num_stages") + assert not cfg, f"unexpected config keys: {list(cfg)}" + assert BLOCK_SIZE_K % 8 == 0 + assert group_size >= BLOCK_SIZE_K or BLOCK_SIZE_K % group_size == 0 + + grid = (num_blocks * triton.cdiv(N, BLOCK_SIZE_N),) + fused_moe_kernel_hybrid_w4a16_grouped[grid]( + A, + B, + C, + B_scale, + block_expert_ids, + block_m_starts, + block_m_counts, + N, + K, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + block_k_diviable=K % BLOCK_SIZE_K == 0, + group_size=group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + compute_type=compute_type, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py index 6f14f0a9e722..70953112a96f 100644 --- a/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py +++ b/vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py @@ -10,6 +10,8 @@ CUDA-graph compatible. """ +import os + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -131,16 +133,26 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # neutral (not measured in this PR). TRITON_BLOCK_SIZE_M = 32 - @staticmethod - def _select_block_size_m(num_tokens: int, topk: int, E: int) -> int: + # Alignment for the group_size <= SMALL_GROUP_SIZE_THRESHOLD branch + # of _triton_config. Must be lcm of the BLOCK_SIZE_M values that + # branch emits (128 for gemm1, 64 for gemm2). + TRITON_BLOCK_SIZE_M_SMALL_GS = 128 + + # group_size <= THRESHOLD uses the small-gs _triton_config branch; + # > THRESHOLD uses the default. Boundary between 32 and 128. + SMALL_GROUP_SIZE_THRESHOLD = 64 + + def _select_block_size_m(self, num_tokens: int, topk: int, E: int) -> int: """Select block size in the M dimension. - Decode (num_tokens <= MAX_SKINNY_BATCH_SIZE): use small block sizes - compatible with the wvSplitK_int4 HIP kernel (N=1..5). - Prefill (num_tokens > MAX_SKINNY_BATCH_SIZE): use the Triton kernel's - BLOCK_SIZE_M for efficient batched GEMM. + Decode (num_tokens <= MAX_SKINNY_BATCH_SIZE): small block sizes + compatible with the wvSplitK_int4 HIP kernel. + Prefill: TRITON_BLOCK_SIZE_M, or TRITON_BLOCK_SIZE_M_SMALL_GS + for small group_size. """ if num_tokens > HybridW4A16MoEExperts.MAX_SKINNY_BATCH_SIZE: + if self._group_size <= HybridW4A16MoEExperts.SMALL_GROUP_SIZE_THRESHOLD: + return HybridW4A16MoEExperts.TRITON_BLOCK_SIZE_M_SMALL_GS return HybridW4A16MoEExperts.TRITON_BLOCK_SIZE_M if num_tokens > 1: avg = num_tokens * topk / E @@ -213,6 +225,27 @@ def _triton_config( BLOCK_SIZE_K = self._group_size # = 128 for the Qwen3.5-A3B path assert BLOCK_SIZE_K % 8 == 0 + if self._group_size <= HybridW4A16MoEExperts.SMALL_GROUP_SIZE_THRESHOLD: + # gemm2 BM=64 < alignment=128; apply()'s _expert_ids_for + # repeat_interleaves expert_ids to compensate. + if K < 1024: + return dict( + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + num_warps=4, + num_stages=1, + ) + return dict( + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=1, + num_warps=8, + num_stages=1, + ) + # Per-shape sweep on Strix Halo (gfx1151) at BLOCK_M=32 (the # current alignment), using benchmarks/kernels/ # sweep_int4g_moe_kernel.py + a per-shape direct call to @@ -276,6 +309,40 @@ def apply( hidden_states, w1, w2, topk_ids ) + # Experimental grouped-GEMM prefill path. Gated off by default. + # Skips moe_align_block_size + virtual gather in favour of a + # physical moe_permute + per-block routing table consumed by a + # dedicated grouped Triton kernel (no padding waste). + if ( + os.environ.get("VLLM_HYBRID_W4A16_GROUPED", "0") == "1" + and num_tokens > self.MAX_SKINNY_BATCH_SIZE + and expert_map is None + and not apply_router_weight_on_input + ): + from vllm.model_executor.layers.fused_moe.hybrid_w4a16_grouped import ( + apply_hybrid_w4a16_grouped, + ) + + activation_out_dim = self.adjust_N_for_activation(N, activation) + apply_hybrid_w4a16_grouped( + output=output, + hidden_states=hidden_states, + w1=w1, + w1_scale=self.w1_scale, + w2=w2, + w2_scale=self.w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=global_num_experts, + activation=activation, + group_size=self._group_size, + gemm1_config=self._triton_config(K, num_tokens * top_k_num), + gemm2_config=self._triton_config( + activation_out_dim, num_tokens * top_k_num + ), + ) + return + if global_num_experts == -1: global_num_experts = E