From 8bf834693b116c34cb180a69deab7b753f3c6252 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Mon, 1 Jun 2026 23:01:38 +0000 Subject: [PATCH 01/18] grouped topk kernel with sigmoid for deepseek r1 integration --- .../moe/moe_routing/grouped_topk_triton.py | 451 +++++++++++++ aiter/ops/triton/moe/moe_routing/routing.py | 81 ++- .../triton_tests/moe/test_grouped_topk.py | 603 ++++++++++++++++++ 3 files changed, 1120 insertions(+), 15 deletions(-) create mode 100644 aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py create mode 100644 op_tests/triton_tests/moe/test_grouped_topk.py diff --git a/aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py b/aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py new file mode 100644 index 0000000000..297793584e --- /dev/null +++ b/aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: MIT +"""Single-fused Triton grouped-top-k routing kernel. + +Drop-in replacement for the ``topk(...)`` call inside +``aiter/ops/triton/moe/moe_routing/routing.py::routing_a8w4`` (lines 338-347). +Same return contract — ``(y_vals, y_indx, Bitmatrix)`` — so downstream +``sort_tokens`` / ``sort_tokens_fused`` consume the output unchanged. + +Algorithm (single kernel launch, mirrors the structure of aiter's ``_topk`` +and ``_hash_routing`` in ``_triton_kernels/moe/moe_routing/topk.py``): + + 1. Memset bitmatrix scratchpad / partials (same lane-borrowing trick as + ``_topk``: the first ``s_blocks + sp_blocks`` programs do nothing but + zero-fill). + 2. Load the row of router logits. + 3. Apply ``score_mode`` per element ('softmax' / 'sigmoid' / 'sqrtsoftplus' / + 'none'). + 4. Per-group score reduction over an *arbitrary* expert→group mapping + (``ExpertGroup`` int32 table): + - HAS_BIAS → top-2 sum on bias-augmented scores (DeepSeek-V3 rule; + mirrors ``biased_grouped_topk_torch``). + - else → per-group max (DeepSeek-V2 rule; mirrors + ``grouped_topk_torch``). + 5. 
Pick top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP is + small, so the unrolled loop is tiny). + 6. Mask experts in non-selected groups to ``-inf`` on the bias-augmented + scores, then do per-expert top-``N_EXPTS_ACT`` via repeated argmax. + 7. Gather *unbiased* weights at the selected indices (matches the + ``noaux_tc`` semantics — bias used for selection only, weights from the + untouched score). + 8. Optional renorm + ``routed_scaling_factor`` scale. + 9. Pack selected indices into the (n_cols_words, n_rows_pad32).T uint32 + bitmatrix layout the kernel emits, identical to ``_topk``. + +Constraints (DeepSeek-class envelope): + - n_expts_tot ≤ 256 (single ``BLOCK_N`` pass; no streaming loop). + - 2 ≤ num_expert_group ≤ 16, power of two (it is a ``tl.arange`` extent). + - topk_group ≤ num_expert_group. + - n_expts_act (top_k) ≤ 16. + - BLOCK_M = 1 (the per-group 3-D intermediate is + ``[BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP]`` fp32 — at BLOCK_M=1 that's + ≤ 256 * 16 * 4 = 16 KiB, fits in registers / LDS comfortably). +""" +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix + + +@triton.jit +def _grouped_topk( + X, # router logits [n_rows, n_expts_tot] (bf16/fp32) + stride_xm, + ExpertGroup, # int32 [n_expts_tot] expert→group_id + Yv, # [n_rows, N_EXPTS_ACT + N_SHARED] selected weights + Yi, # [n_rows, N_EXPTS_ACT + N_SHARED] selected expert ids (int16) + stride_ym, + Bits, # bitmatrix data + stride_rm, + stride_rn, + n_rows, + n_expts_tot, + S, # bitmatrix scratchpad — zero-filled by this kernel + BLOCK_S: tl.constexpr, + s_blocks, + SP, # bitmatrix partials — zero-filled by this kernel + BLOCK_SP: tl.constexpr, + sp_blocks, + sp_size, + BLOCK_M: tl.constexpr, + N_EXPTS_PAD: tl.constexpr, + BLOCK_N: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, + N_EXPTS_ACT_PAD: tl.constexpr, + NUM_EXPERT_GROUP: tl.constexpr, + TOPK_GROUP: tl.constexpr, + Bias=None, + SCORE_MODE: tl.constexpr = "softmax", + HAS_BIAS: tl.constexpr = False, + 
APPLY_RENORM: tl.constexpr = False, + ROUTED_SCALING: tl.constexpr = 1.0, + N_SHARED: tl.constexpr = 0, + SHARED_SCORE: tl.constexpr = 1.0, +): + pid = tl.program_id(0) + + # -- Memset bitmatrix scratchpads (same idiom as _topk / _hash_routing). + if pid < s_blocks: + tl.store( + S + BLOCK_S * pid + tl.arange(0, BLOCK_S), + tl.zeros([BLOCK_S], tl.int32), + ) + elif pid < s_blocks + sp_blocks: + offs = BLOCK_SP * (pid - s_blocks) + tl.arange(0, BLOCK_SP) + tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size) + + if pid * BLOCK_M >= n_rows: + return + + tl.static_assert(BLOCK_N % 32 == 0) + tl.static_assert( + N_EXPTS_PAD == BLOCK_N, + "DeepSeek-class envelope: BLOCK_N must equal N_EXPTS_PAD (single-block).", + ) + + x_dtype: tl.constexpr = X.dtype.element_ty + + offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m < n_rows + offs_n = tl.arange(0, BLOCK_N) + mask_n = offs_n < n_expts_tot + + # -- 1. Load logits. + X_ptrs = X + offs_m[:, None] * stride_xm + offs_n[None, :] + x = tl.load(X_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0) + + # -- 2. Score transform. + if SCORE_MODE == "softmax": + # Numerically-stable row softmax with masked-out lanes set to -inf. + x_f = tl.where(mask_n[None, :], x.to(tl.float32), float("-inf")) + x_max = tl.max(x_f, axis=1, keep_dims=True) + x_e = tl.exp(x_f - x_max) + x_e = tl.where(mask_n[None, :], x_e, 0.0) + scores = x_e / (tl.sum(x_e, axis=1, keep_dims=True) + 1e-30) + elif SCORE_MODE == "sigmoid": + scores = 1.0 / (1.0 + tl.exp(-x.to(tl.float32))) + elif SCORE_MODE == "sqrtsoftplus": + x_f = x.to(tl.float32) + sp = tl.maximum(x_f, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(x_f))) + scores = tl.sqrt(sp) + else: + scores = x.to(tl.float32) + + # Pad-lane safety: invalid columns must lose every comparison. + scores = tl.where(mask_n[None, :], scores, float("-inf")) + + # -- 3. Bias-augmented choice scores. 
Weights are gathered later from the + # untouched ``scores`` (matches biased_grouped_topk_torch + + # FusedMoE.select_experts sigmoid path: select on s+b, return s). + if HAS_BIAS: + b = tl.load(Bias + offs_n, mask=mask_n, other=0.0).to(tl.float32) + scores_for_choice = scores + b[None, :] + else: + scores_for_choice = scores + + # -- 4. Per-group reduction over arbitrary expert→group mapping. + gid = tl.load(ExpertGroup + offs_n, mask=mask_n, other=0).to(tl.int32) + g_arange = tl.arange(0, NUM_EXPERT_GROUP) + gid_eq = gid[:, None] == g_arange[None, :] # [BLOCK_N, NUM_EXPERT_GROUP] + + # 3-D one-hot expand: [BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP], with -inf + # outside each group's column. + sfc_3d = scores_for_choice[:, :, None].broadcast_to( + BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP + ) + expanded = tl.where(gid_eq[None, :, :], sfc_3d, float("-inf")) + group_max1 = tl.max(expanded, axis=1) # [BLOCK_M, NUM_EXPERT_GROUP] + + if HAS_BIAS: + # Top-2-sum-per-group. To find the second-largest score per group + # without tl.argmax-on-3D, suppress the per-group max by exact-equality + # match (ties on float scores are negligible in DeepSeek workloads). + gm1_per_e = tl.sum( + gid_eq[None, :, :].to(tl.float32) * group_max1[:, None, :], + axis=2, + ) # [BLOCK_M, BLOCK_N] + suppressed = tl.where( + scores_for_choice == gm1_per_e, float("-inf"), scores_for_choice + ) + sup_3d = suppressed[:, :, None].broadcast_to( + BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP + ) + expanded2 = tl.where(gid_eq[None, :, :], sup_3d, float("-inf")) + group_max2 = tl.max(expanded2, axis=1) + group_scores = group_max1 + group_max2 + else: + group_scores = group_max1 + + # -- 5. Top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP + # is small; static-range unroll). 
+ group_mask_i = tl.zeros([BLOCK_M, NUM_EXPERT_GROUP], dtype=tl.int32) + gs = group_scores + for _gj in tl.static_range(TOPK_GROUP): + am_g = tl.argmax(gs, axis=1).to(tl.int32) # [BLOCK_M] + sel_g = (g_arange[None, :] == am_g[:, None]) # [BLOCK_M, NUM_EXPERT_GROUP] + group_mask_i = group_mask_i | sel_g.to(tl.int32) + gs = tl.where(sel_g, float("-inf"), gs) + + # -- 6. Per-(token, expert) keep-mask via group-id lookup, then suppress + # experts in non-selected groups on the bias-augmented scores. + expert_keep = tl.sum( + gid_eq[None, :, :].to(tl.int32) * group_mask_i[:, None, :], + axis=2, + ) > 0 # [BLOCK_M, BLOCK_N] + sfc_masked = tl.where(expert_keep, scores_for_choice, float("-inf")) + + # -- 7. Per-expert top-``N_EXPTS_ACT`` via repeated argmax. Padded slots + # (N_EXPTS_ACT_PAD > N_EXPTS_ACT) are kept in the y_indices/y_values + # buffers but masked off on the writeback / bitmatrix-pack. + n_arange = tl.arange(0, BLOCK_N) + y_indices = tl.zeros([BLOCK_M, N_EXPTS_ACT_PAD], dtype=tl.int32) + sfc_iter = sfc_masked + for kj in tl.static_range(N_EXPTS_ACT): + am_k = tl.argmax(sfc_iter, axis=1).to(tl.int32) # [BLOCK_M] + slot_eq = (tl.arange(0, N_EXPTS_ACT_PAD) == kj)[None, :] + y_indices = tl.where(slot_eq, am_k[:, None], y_indices) + sfc_iter = tl.where( + n_arange[None, :] == am_k[:, None], float("-inf"), sfc_iter + ) + + # -- 8. Gather UNBIASED weights at selected indices. + pos_eq = ( + n_arange[None, None, :] == y_indices[:, :, None] + ) # [BLOCK_M, K_PAD, BLOCK_N] + scores_3d = scores[:, None, :].broadcast_to(BLOCK_M, N_EXPTS_ACT_PAD, BLOCK_N) + y_weights = tl.sum(tl.where(pos_eq, scores_3d, 0.0), axis=2) # [BLOCK_M, K_PAD] + + # Routed-slot mask: the first N_EXPTS_ACT slots hold the grouped-topk + # selection (shared experts, if any, occupy the next N_SHARED slots and + # must be excluded from the routed renorm denominator). + k_arange = tl.arange(0, N_EXPTS_ACT_PAD) + routed_mask = k_arange[None, :] < N_EXPTS_ACT + + # -- 9. 
Renorm + scale over the ROUTED slots only (mirrors _topk's + # APPLY_RENORM / ROUTED_SCALING and the noaux_tc semantics where the + # always-on shared expert is appended unscaled after renorm). + if APPLY_RENORM: + y_f = tl.where(routed_mask, y_weights, 0.0) + s = tl.sum(y_f, axis=1, keep_dims=True) + y_weights = y_f / (s + 1e-20) * ROUTED_SCALING + elif ROUTED_SCALING != 1.0: + y_weights = y_weights * ROUTED_SCALING + + # -- 9b. Append fused shared expert(s): always-on, fixed id n_expts_tot+i + # and fixed weight SHARED_SCORE (matches init_aiter_topK_meta_data / + # rocm_aiter_grouped_topk). Placed AFTER renorm so the shared weight + # is not folded into the routed normalization. + if N_SHARED > 0: + shared_slot = (k_arange[None, :] >= N_EXPTS_ACT) & ( + k_arange[None, :] < N_EXPTS_ACT + N_SHARED + ) + shared_idx = (n_expts_tot + k_arange - N_EXPTS_ACT)[None, :].to(tl.int32) + y_indices = tl.where(shared_slot, shared_idx, y_indices) + y_weights = tl.where(shared_slot, SHARED_SCORE, y_weights) + real_mask = k_arange[None, :] < (N_EXPTS_ACT + N_SHARED) + else: + real_mask = routed_mask + + y_values_out = y_weights.to(x_dtype) + + # -- 10. Writeback selected weights / indices. + Yv_ptrs = Yv + offs_m[:, None] * stride_ym + k_arange[None, :] + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + k_arange[None, :] + write_mask = mask_m[:, None] & real_mask + tl.store(Yv_ptrs, y_values_out, mask=write_mask) + tl.store(Yi_ptrs, y_indices, mask=write_mask) + + # -- 11. Pack into bitmatrix (mirrors _topk's tail). 
+ safe_idx = tl.where(real_mask, y_indices, 0).to(tl.uint32) + y_div = safe_idx // 32 + y_rem = safe_idx % 32 + bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N # = 1 (single-block) + for i in range(bm_iters): + offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) + y2 = tl.where( + (y_div[:, :, None] == offs_r_n[None, None, :]) & real_mask[:, :, None], + (1 << y_rem)[:, :, None], + 0, + ) + r = tl.reduce_or(y2, axis=1) + BitsPtrs = ( + Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn + ) + tl.store(BitsPtrs, r, mask=mask_m[:, None]) + + +# --------------------------------------------------------------------------- +# Python wrapper — drop-in for the topk(...) call at routing.py:338-347. +# --------------------------------------------------------------------------- + + +def grouped_topk( + x: torch.Tensor, + k: int, + num_expert_group: int, + topk_group: int, + *, + expert_group: torch.Tensor | None = None, + apply_softmax: bool = False, # accepted for parity with topk(); ignored + HIST_BLOCK_M: int = 32, + score_mode: str = "softmax", + bias: torch.Tensor | None = None, + renorm: bool = False, + routed_scaling_factor: float = 1.0, + num_fused_shared_experts: int = 0, + shared_experts_score: float = 1.0, +): + """Triton grouped top-k expert selection. See module docstring. + + Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of + ``aiter.ops.triton.moe.moe_routing.topk.topk``: + + - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. + - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. + + When ``num_fused_shared_experts > 0`` the routed top-k selection occupies + the first ``k`` columns and the always-on shared expert(s) occupy the next + ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight + ``shared_experts_score`` (appended after the routed renorm, mirroring + ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). 
The bitmatrix + is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` + counts the shared bucket. + + - bitmatrix: real :class:`Bitmatrix`; same uint32 + ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the + ``_topk`` kernel emits, so ``sort_tokens`` and ``sort_tokens_fused`` + consume it unchanged. + """ + assert x.dim() == 2 + n_rows, n_cols = x.shape + assert n_cols <= 256, ( + f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" + ) + # Fused shared experts are appended (always-on) AFTER the routed selection; + # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). + n_shared = num_fused_shared_experts + assert n_shared >= 0 + n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) + k_out = k + n_shared # output width (routed top-k + shared) + assert num_expert_group > 1 + assert num_expert_group <= 16, ( + f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" + ) + assert 0 < topk_group <= num_expert_group + assert 0 < k <= 16 + assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( + f"unknown score_mode {score_mode!r}" + ) + has_bias = bias is not None + if has_bias: + assert bias.dim() == 1 and bias.shape[0] == n_cols + assert bias.dtype == torch.float32 + assert score_mode in ("sqrtsoftplus", "sigmoid"), ( + "bias only supported with sqrtsoftplus / sigmoid" + ) + + dev = x.device + + # Default expert→group mapping = contiguous DeepSeek layout. + if expert_group is None: + assert n_cols % num_expert_group == 0, ( + f"n_expts_tot ({n_cols}) not divisible by num_expert_group " + f"({num_expert_group}); pass an explicit expert_group table." 
+ ) + g_size = n_cols // num_expert_group + expert_group = ( + torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size + ).to(torch.int32) + else: + assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols + assert expert_group.dtype == torch.int32 + + # Block sizes — single BLOCK_N pass for DeepSeek envelope. BLOCK_N must + # cover the shared-expert columns too so their bits fit in the bitmatrix. + BLOCK_M = 1 + BLOCK_N = max(32, triton.next_power_of_2(n_total)) + N_EXPTS_PAD = BLOCK_N + # Mirror topk(): pad to ≥ 2 to dodge tl.argmax/topk(k=1) compile quirks. + N_EXPTS_ACT_PAD = max(2, triton.next_power_of_2(k_out)) + BLOCK_S = 128 + BLOCK_SP = 128 + TILE_SIZE = 8 + + # Outputs (same shapes / dtypes as topk(...)), widened by the shared slots. + y_vals = torch.empty((n_rows, k_out), dtype=x.dtype, device=dev) + y_indx = torch.empty((n_rows, k_out), dtype=torch.int16, device=dev) + + # Bitmatrix in transposed-uint32 storage layout (identical to topk()). + n_cols_pad = triton.cdiv(n_total, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + bitmatrix_data = torch.empty( + (n_cols_words, triton.cdiv(n_rows, 32) * 32), + dtype=torch.uint32, + device=dev, + ) + bitmatrix_data = torch.transpose(bitmatrix_data, 0, 1)[:n_rows] + + # Scratchpads. The per-column sum buffer consumed by Bitmatrix.sum() / + # sort_tokens must cover the full padded column count (n_cols_pad), which + # widens with the shared experts; sizing by n_total alone can under-allocate + # (e.g. n_total=257 -> n_cols_pad=512 but cdiv(257,128)*128=384). 
+ s_blocks = triton.cdiv(n_cols_pad, BLOCK_S) + s_cols = s_blocks * BLOCK_S + scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) + BLOCK_MM = HIST_BLOCK_M * TILE_SIZE + pids_x = triton.cdiv(n_rows, BLOCK_MM) + scratchpad_partials = torch.empty( + (n_cols_pad, pids_x * TILE_SIZE), dtype=torch.int32, device=dev + ) + scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) + sp_size = scratchpad_partials.numel() + sp_blocks = triton.cdiv(sp_size, BLOCK_SP) + + pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) + + _grouped_topk[(pids,)]( + x, + x.stride(0), + expert_group, + y_vals, + y_indx, + y_vals.stride(0), + bitmatrix_data, + bitmatrix_data.stride(0), + bitmatrix_data.stride(1), + n_rows, + n_cols, + scratchpad, + BLOCK_S, + s_blocks, + scratchpad_partials, + BLOCK_SP, + sp_blocks, + sp_size, + BLOCK_M=BLOCK_M, + N_EXPTS_PAD=N_EXPTS_PAD, + BLOCK_N=BLOCK_N, + N_EXPTS_ACT=k, + N_EXPTS_ACT_PAD=N_EXPTS_ACT_PAD, + NUM_EXPERT_GROUP=num_expert_group, + TOPK_GROUP=topk_group, + Bias=bias, + SCORE_MODE=score_mode, + HAS_BIAS=has_bias, + APPLY_RENORM=renorm, + ROUTED_SCALING=routed_scaling_factor, + N_SHARED=n_shared, + SHARED_SCORE=shared_experts_score, + num_warps=4, + ) + + bitmatrix = Bitmatrix( + bitmatrix_data, + shape=[n_rows, n_cols_words * 32], + scratchpad=scratchpad, + scratchpad_partials=scratchpad_partials, + ) + return y_vals, y_indx, bitmatrix diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index c5f1ef3a08..1ff76df03c 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -321,6 +321,12 @@ def routing_a8w4( bias: torch.Tensor | None = None, renorm: bool = True, routed_scaling_factor: float = 1.0, + use_grouped_topk: bool = False, + num_expert_group: int | None = None, + topk_group: int | None = None, + expert_group: torch.Tensor | None = None, + num_fused_shared_experts: int = 0, + shared_experts_score: float = 
1.0, ): """All-Triton routing for the a8w4 path: fused V4 routing math + sort. @@ -330,27 +336,72 @@ def routing_a8w4( 2. aiter `sort_tokens` (or `sort_tokens_fused` for tiny M): sort tokens by expert and produce ExptData specialized for the given ``block_m``. + When ``use_grouped_topk=True``, step 1 is replaced by the single-fused + Triton ``grouped_topk`` kernel + (``aiter.ops.triton.moe.moe_routing.grouped_topk_triton.grouped_topk``) — + DeepSeek-V2/V3-style hierarchical routing (pick ``topk_group`` groups out of + ``num_expert_group``, then top-``n_expts_act`` experts within those + groups). Same return contract as ``topk`` (y_vals, y_indx, Bitmatrix), so + ``sort_tokens`` / ``sort_tokens_fused`` consume it unchanged. + Returns (RoutingData, gather_indx, scatter_indx) where gather_indx and scatter_indx are raw int32 tensors (no GatherIndx/ScatterIndx wrappers) — consumed directly by ``moe_gemm_a8w4``. No multi-block_m dict, no triton_kernels wrapper, no Python bridge step. """ - from .topk import topk - - n_tokens, n_expts_tot = logits.shape - - # Step 1: extended topk does sqrtsoftplus + bias + topk + bitmatrix + renorm + scale. - expt_scal, expt_indx, bitmatrix = topk( - logits, - n_expts_act, - apply_softmax=False, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=routed_scaling_factor, - HIST_BLOCK_M=32, - ) + n_tokens, n_routed = logits.shape + + # Fused shared experts are appended (always-on) to every token by the + # grouped-topk kernel, occupying expert ids [n_routed, n_routed + n_shared). + # They widen both the per-token selection (n_expts_act) and the total + # expert count used for the sort / histogram. + n_shared = num_fused_shared_experts + n_expts_tot = n_routed + n_shared + + # Step 1: per-token expert selection. Either flat top-k (existing aiter + # _topk kernel) or grouped top-k (the fused _grouped_topk kernel) — both + # return (y_vals, y_indx, Bitmatrix) with the same downstream contract.
+ if use_grouped_topk and num_expert_group != 1: + assert ( + num_expert_group is not None and topk_group is not None + ), "use_grouped_topk requires num_expert_group and topk_group" + # Lazy import: the grouped-topk kernel is only loaded on this path. + from aiter.ops.triton.moe.moe_routing.grouped_topk_triton import grouped_topk + + expt_scal, expt_indx, bitmatrix = grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + expert_group=expert_group, + apply_softmax=False, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + num_fused_shared_experts=n_shared, + shared_experts_score=shared_experts_score, + HIST_BLOCK_M=32, + ) + # Routed top-k + appended shared experts per token. + n_expts_act = n_expts_act + n_shared + else: + assert n_shared == 0, ( + "fused shared experts are only supported on the grouped-topk path" + ) + from .topk import topk + + expt_scal, expt_indx, bitmatrix = topk( + logits, + n_expts_act, + apply_softmax=False, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=routed_scaling_factor, + HIST_BLOCK_M=32, + ) # Step 2: sort tokens by expert and build ExptData for the chosen block_m. if n_tokens <= 16: diff --git a/op_tests/triton_tests/moe/test_grouped_topk.py b/op_tests/triton_tests/moe/test_grouped_topk.py new file mode 100644 index 0000000000..ca3767b5bf --- /dev/null +++ b/op_tests/triton_tests/moe/test_grouped_topk.py @@ -0,0 +1,603 @@ +"""Unit tests for ATOM's single-fused Triton grouped-top-k routing kernel +(``atom.model_ops.grouped_topk_triton.grouped_topk``). + +Structured after ``test_moe_routing.py``: + * Reference uses aiter's torch grouped-topk + (``biased_grouped_topk_torch`` / ``grouped_topk_torch``) for the standard + contiguous DeepSeek group layout, plus a thin wrapper for the + ``sqrtsoftplus`` score mode and the ``routed_scaling_factor`` scale that + the aiter refs don't apply. 
+ * ``(y_vals, y_indx)`` are compared per-row set-wise (sorted by expert id), + robust to the kernel returning experts in descending-score order. + * The emitted ``Bitmatrix`` is decoded and checked against the selected + expert set. + * End-to-end ``routing_a8w4(use_grouped_topk=True)`` is validated through the + sort_tokens / ExptData pipeline with a bucket-multiset check. +""" + +import pytest +import torch +import torch.nn.functional as F + +from aiter.ops.triton.utils._triton.arch_info import get_arch +from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch +from aiter.ops.triton.moe.moe_routing.routing import ( + routing_a8w4, + compute_expt_data_torch, +) + +# grouped_topk lives in ATOM; skip the whole module if ATOM isn't importable +# in this environment (e.g. aiter-only CI). +atom_grouped_topk = pytest.importorskip( + "atom.model_ops.grouped_topk_triton" +).grouped_topk + + +# -------------------------------------------------------------------------- +# comparison helpers (copied from test_moe_routing.py for self-containment) +# -------------------------------------------------------------------------- + + +def assert_equal(ref, tri): + if isinstance(ref, torch.Tensor): + assert ((ref.cpu().numpy() - tri.cpu().numpy()) ** 2).sum() == 0 + else: + assert ref == tri + + +def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): + if maxtol is None: + maxtol = 2e-2 + if rmstol is None: + rmstol = 4e-3 + ref = ref.to(torch.float32).detach() + tri = tri.to(torch.float32).detach() + assert ( + ref.shape == tri.shape + ), f"Tensors must have same size {ref.shape=} {tri.shape=}" + + inf_mask_ref = torch.isinf(ref) + inf_mask_tri = torch.isinf(tri) + assert torch.equal( + inf_mask_ref, inf_mask_tri + ), "Tensor must have same infinite elements" + refn = torch.where(inf_mask_ref, 0, ref) + trin = torch.where(inf_mask_tri, 0, tri) + + eps = 1.0e-30 + multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) + refn *= 
multiplier + trin *= multiplier + + ref_rms = torch.sqrt(torch.square(refn).mean()) + eps + rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) + max_err = torch.max(rel_err).item() + rms_err = torch.sqrt(torch.square(rel_err).mean()).item() + + if verbose: + print(f"{description} max rel err = {max_err} (thr {maxtol})") + print(f"{description} rms rel err = {rms_err} (thr {rmstol})") + + assert max_err <= maxtol + assert rms_err <= rmstol + + +def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"): + return torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device) + + +# -------------------------------------------------------------------------- +# torch references +# -------------------------------------------------------------------------- + + +def _ref_sqrtsoftplus_grouped( + logits, bias, k, num_expert_group, topk_group, renorm, scale +): + """sqrtsoftplus grouped-topk reference (no aiter equivalent exists). + + Mirrors the kernel: sqrt(softplus(logits)) transform, bias added for + SELECTION only, top-2-sum-per-group when biased else per-group max, mask + non-selected groups, top-k on the (biased) choice scores, gather UNBIASED + weights, renorm + scale. 
+ """ + nt, ne = logits.shape + g_size = ne // num_expert_group + transform = torch.sqrt(F.softplus(logits.float())) + choice = transform + bias.float().unsqueeze(0) if bias is not None else transform + + if bias is not None: + group_scores = ( + choice.view(nt, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = choice.view(nt, num_expert_group, -1).max(dim=-1).values + + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1.0) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(nt, num_expert_group, g_size) + .reshape(nt, ne) + .bool() + ) + tmp = choice.masked_fill(~score_mask, float("-inf")) + ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices + w = transform.gather(1, ids) + if renorm: + w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) + w = w * scale + return w.float(), ids.to(torch.int64) + + +def _ref_contiguous( + logits, k, num_expert_group, topk_group, score_mode, bias, renorm, scale +): + """Reference for contiguous DeepSeek group layout. Reuses aiter torch refs + where they apply, plus the sqrtsoftplus wrapper + scale.""" + if score_mode == "sqrtsoftplus": + return _ref_sqrtsoftplus_grouped( + logits, bias, k, num_expert_group, topk_group, renorm, scale + ) + if score_mode == "sigmoid" and bias is not None: + w, ids = biased_grouped_topk_torch( + logits, bias, k, renorm, num_expert_group, topk_group + ) + elif score_mode in ("sigmoid", "softmax"): + w, ids = grouped_topk_torch( + logits, k, renorm, num_expert_group, topk_group, scoring_func=score_mode + ) + else: + raise ValueError(score_mode) + return w.float() * scale, ids.to(torch.int64) + + +def _ref_arbitrary_grouped( + logits, expert_group, k, num_expert_group, topk_group, score_mode, bias, renorm, scale +): + """General reference honoring an arbitrary expert->group table (equal-size + groups). 
Used for the non-contiguous mapping case where the aiter refs + (which assume contiguous .view groups) don't apply.""" + nt, ne = logits.shape + f32 = logits.float() + if score_mode == "softmax": + scores = torch.softmax(f32, dim=-1) + elif score_mode == "sigmoid": + scores = f32.sigmoid() + elif score_mode == "sqrtsoftplus": + scores = torch.sqrt(F.softplus(f32)) + else: + scores = f32 + choice = scores + bias.float().unsqueeze(0) if bias is not None else scores + + group_scores = torch.empty((nt, num_expert_group), device=logits.device) + for g in range(num_expert_group): + cols = (expert_group == g).nonzero(as_tuple=False).flatten() + sub = choice[:, cols] + if bias is not None: + group_scores[:, g] = sub.topk(2, dim=-1)[0].sum(dim=-1) + else: + group_scores[:, g] = sub.max(dim=-1).values + + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices + group_sel = torch.zeros((nt, num_expert_group), device=logits.device, dtype=torch.bool) + group_sel.scatter_(1, group_idx, True) + # expert keep mask via group table lookup + expert_keep = group_sel[:, expert_group.long()] # (nt, ne) + + tmp = choice.masked_fill(~expert_keep, float("-inf")) + ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices + w = scores.gather(1, ids) + if renorm: + w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) + w = w * scale + return w.float(), ids.to(torch.int64) + + +# -------------------------------------------------------------------------- +# output comparison utilities +# -------------------------------------------------------------------------- + + +def _row_sort_by_id(ids, weights): + order = torch.argsort(ids, dim=1) + return torch.gather(ids, 1, order), torch.gather(weights, 1, order) + + +def _assert_selection_matches(ref_ids, ref_w, tri_ids, tri_w): + """Set-wise per-row comparison: sort both by expert id, then assert ids + identical and gathered weights close.""" + ref_ids_s, ref_w_s = _row_sort_by_id(ref_ids.cpu(), ref_w.cpu()) + tri_ids_s, tri_w_s = 
_row_sort_by_id(tri_ids.cpu().long(), tri_w.cpu().float()) + assert torch.equal(ref_ids_s, tri_ids_s), ( + f"selected expert ids differ:\nref={ref_ids_s}\ntri={tri_ids_s}" + ) + assert_close(ref_w_s, tri_w_s, 2e-2, 4e-3, description="weights") + + +def _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot): + """Decode the packed uint32 Bitmatrix into a (n_tokens, n_expts_tot) bool + matrix of selected experts.""" + data = bitmatrix.data[:n_tokens].to(torch.int64) # (n_tokens, n_cols_words) + n_cols_words = data.shape[1] + bits = torch.arange(32, device=data.device, dtype=torch.int64) + unpacked = ((data.unsqueeze(-1) >> bits) & 1).bool() # (nt, words, 32) + unpacked = unpacked.reshape(n_tokens, n_cols_words * 32) + return unpacked[:, :n_expts_tot] + + +def _assert_bitmatrix_matches(bitmatrix, tri_ids, n_tokens, n_expts_tot): + decoded = _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot).cpu() + expected = torch.zeros((n_tokens, n_expts_tot), dtype=torch.bool) + expected.scatter_(1, tri_ids.cpu().long(), True) + assert torch.equal(decoded, expected), "bitmatrix does not match selected ids" + + +# -------------------------------------------------------------------------- +# end-to-end routing helpers (mirror of test_moe_routing.py, compacted) +# -------------------------------------------------------------------------- + + +def _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m): + n_tokens, n_expts_act = expt_scal.shape + n_gates = n_tokens * n_expts_act + scal_flat = expt_scal.reshape(-1) + indx_flat = expt_indx.reshape(-1).to(torch.int32) + topk_indx = torch.argsort(indx_flat, stable=True).to(torch.int32) + gate_indx = torch.argsort(topk_indx, stable=True).to(torch.int32) + gate_scal = scal_flat[topk_indx.long()] + hist = torch.histc( + indx_flat.float(), bins=n_expts_tot, min=0, max=n_expts_tot - 1 + ).int() + expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates, block_m) + return hist, topk_indx, gate_indx, gate_scal, expt_data + + +def 
_check_routing_data_bucket( + ref_pack, tri_routing_data, tri_gather, tri_scatter, topk_weights, topk_ids +): + ref_hist, _, _, _, ref_expt_data = ref_pack + assert_equal(ref_hist, tri_routing_data.expt_hist) + assert_equal(ref_expt_data.hist, tri_routing_data.expt_data.hist) + assert_equal(ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw) + assert_equal(ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad) + assert_equal(ref_expt_data.block_pid_map, tri_routing_data.expt_data.block_pid_map) + + n_tokens, n_expts_act = topk_ids.shape + n_gates = n_tokens * n_expts_act + n_expts_tot = ref_hist.numel() + + iota = torch.arange(n_gates, dtype=torch.int32, device=tri_gather.device) + assert torch.equal(tri_scatter[tri_gather.long()], iota), "scatter[gather[j]] != j" + + flat_ids = topk_ids.reshape(-1).cpu().tolist() + flat_w = topk_weights.reshape(-1).float().cpu().tolist() + src = tri_gather.cpu().tolist() + scal = tri_routing_data.gate_scal.float().cpu().tolist() + cum = torch.cumsum(ref_hist, dim=0).cpu().tolist() + + ground = {e: [] for e in range(n_expts_tot)} + for i, e in enumerate(flat_ids): + ground[e].append((i // n_expts_act, flat_w[i])) + for e in ground: + ground[e].sort() + + got = {e: [] for e in range(n_expts_tot)} + e = 0 + for j in range(n_gates): + while e < n_expts_tot and j >= cum[e]: + e += 1 + assert flat_ids[src[j]] == e, f"bucket-invariant violated at pos {j}" + got[e].append((src[j] // n_expts_act, scal[j])) + for e in got: + got[e].sort() + + for e in range(n_expts_tot): + rb, tb = ground[e], got[e] + assert len(rb) == len(tb), f"expert {e}: ref={len(rb)} test={len(tb)}" + for (tt_r, w_r), (tt_t, w_t) in zip(rb, tb): + assert tt_r == tt_t, f"expert {e}: token ref={tt_r} test={tt_t}" + assert abs(w_r - w_t) <= 1e-6, f"expert {e} token {tt_r}: w {w_r} vs {w_t}" + + +# -------------------------------------------------------------------------- +# parametrization +# 
-------------------------------------------------------------------------- + +# (n_expts_tot, num_expert_group, topk_group, n_expts_act) — DeepSeek-like. +GROUP_SHAPES = [ + (256, 8, 4, 8), + (128, 8, 4, 6), +] +# n_tokens spanning the fused (<=16) and regular sort_tokens paths. +N_TOKENS = [8, 16, 64, 1024] +# (score_mode, has_bias, renorm, routed_scaling_factor) — production-core set. +SCORE_COMBOS = [ + ("sqrtsoftplus", True, True, 2.5), + ("sigmoid", True, True, 1.0), + ("softmax", False, False, 1.0), +] + + +def _maybe_skip(): + if not torch.cuda.is_available(): + pytest.skip("grouped_topk requires a GPU") + if get_arch() not in ["gfx950", "gfx1250"]: + pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") + + +# -------------------------------------------------------------------------- +# 1. direct kernel test: (y_vals, y_indx, bitmatrix) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", N_TOKENS) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("score_mode, has_bias, renorm, scale", SCORE_COMBOS) +def test_grouped_topk_kernel( + n_tokens, + n_expts_tot, + num_expert_group, + topk_group, + n_expts_act, + score_mode, + has_bias, + renorm, + scale, +): + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = ( + torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + if has_bias + else None + ) + + ref_w, ref_ids = _ref_contiguous( + logits.clone(), + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = atom_grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + assert y_vals.shape == 
(n_tokens, n_expts_act) + assert y_indx.shape == (n_tokens, n_expts_act) + assert y_indx.dtype == torch.int16 + assert y_vals.dtype == logits.dtype + + _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) + + +# -------------------------------------------------------------------------- +# 2. arbitrary (non-contiguous) expert->group mapping +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +def test_grouped_topk_arbitrary_group( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act +): + _maybe_skip() + device = "cuda" + torch.manual_seed(7) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + + # Equal-size groups but a shuffled (non-contiguous) expert->group table. 
+ g_size = n_expts_tot // num_expert_group + perm = torch.randperm(n_expts_tot, device=device) + expert_group = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + for g in range(num_expert_group): + expert_group[perm[g * g_size : (g + 1) * g_size]] = g + + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + ref_w, ref_ids = _ref_arbitrary_grouped( + logits.clone(), + expert_group, + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = atom_grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + expert_group=expert_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) + + +# -------------------------------------------------------------------------- +# 3. end-to-end routing_a8w4(use_grouped_topk=True) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("block_m", [16, 32]) +def test_routing_a8w4_grouped( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m +): + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + # The selection the kernel makes (deterministic for fixed inputs); used as + # ground truth for the sort/scatter pipeline check. 
+ y_vals, y_indx, _ = atom_grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + tri_routing_data, tri_gather, tri_scatter = routing_a8w4( + logits, + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + use_grouped_topk=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + + ref_pack = _sort_and_build_torch( + y_vals.float(), y_indx.to(torch.int32), n_expts_tot, block_m + ) + _check_routing_data_bucket( + ref_pack, tri_routing_data, tri_gather, tri_scatter, y_vals.float(), y_indx + ) + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.block_m == block_m + + +# -------------------------------------------------------------------------- +# 4. fused shared experts (DeepSeek-R1/V3 always-on shared expert) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("n_shared", [1, 2]) +def test_grouped_topk_shared_expert( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared +): + """The kernel appends `n_shared` always-on shared experts (id n_expts_tot+i, + weight 1.0) AFTER the routed renorm. 
The routed portion must still match the + reference, and the shared columns + bitmatrix must reflect the append.""" + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + ref_w, ref_ids = _ref_contiguous( + logits.clone(), n_expts_act, num_expert_group, topk_group, + score_mode, bias, renorm, scale, + ) + y_vals, y_indx, bitmatrix = atom_grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + num_fused_shared_experts=n_shared, + shared_experts_score=1.0, + ) + + assert y_vals.shape == (n_tokens, n_expts_act + n_shared) + assert y_indx.shape == (n_tokens, n_expts_act + n_shared) + + # Routed slots (first n_expts_act) must match the reference selection. + _assert_selection_matches( + ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] + ) + + # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. + for i in range(n_shared): + ids_i = y_indx[:, n_expts_act + i].cpu().long() + w_i = y_vals[:, n_expts_act + i].float().cpu() + assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" + assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" + + # Bitmatrix must contain routed + shared selections over the widened width. 
+ _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) + + +@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("block_m", [16, 32]) +@pytest.mark.parametrize("n_shared", [1, 2]) +def test_routing_a8w4_grouped_shared( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m, n_shared +): + """End-to-end routing_a8w4 with fused shared experts: histogram must include + a full shared bucket (n_tokens) per shared expert and the gather/scatter must + form a valid inverse permutation over the widened gate count.""" + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + rd, gather, scatter = routing_a8w4( + logits, + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + use_grouped_topk=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + num_fused_shared_experts=n_shared, + ) + + assert rd.n_expts_tot == n_expts_tot + n_shared + assert rd.n_expts_act == n_expts_act + n_shared + + # Every token is routed to each shared expert exactly once. 
+ for i in range(n_shared): + assert rd.expt_hist[n_expts_tot + i].item() == n_tokens + assert rd.expt_hist.sum().item() == n_tokens * (n_expts_act + n_shared) + + n_gates = n_tokens * (n_expts_act + n_shared) + iota = torch.arange(n_gates, dtype=torch.int32, device=gather.device) + assert torch.equal(scatter[gather.long()], iota), "scatter[gather[j]] != j" From 8d679de0bc1ab5010834420320ca569444669c09 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Tue, 2 Jun 2026 17:35:13 +0000 Subject: [PATCH 02/18] cleanup after rebase --- aiter/ops/triton/moe/moe_op_gemm_a8w4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py index bda2efc0c7..96d742bfcb 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py @@ -5,6 +5,7 @@ import itertools import os import json +import json import torch import triton from aiter.ops.triton.moe.moe_routing.routing import RoutingData From ef709192314249bd3fe3ec8f6e9aba539e73bc10 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Tue, 2 Jun 2026 18:16:52 +0000 Subject: [PATCH 03/18] restructured grouped topk kernel --- .../moe/moe_routing/grouped_topk.py} | 216 ----------------- .../triton/moe/moe_routing/grouped_topk.py | 220 ++++++++++++++++++ aiter/ops/triton/moe/moe_routing/routing.py | 4 +- 3 files changed, 222 insertions(+), 218 deletions(-) rename aiter/ops/triton/{moe/moe_routing/grouped_topk_triton.py => _triton_kernels/moe/moe_routing/grouped_topk.py} (52%) create mode 100644 aiter/ops/triton/moe/moe_routing/grouped_topk.py diff --git a/aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py similarity index 52% rename from aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py rename to aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py index 297793584e..daa0c1e3dc 100644 --- 
a/aiter/ops/triton/moe/moe_routing/grouped_topk_triton.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py @@ -1,46 +1,3 @@ -# SPDX-License-Identifier: MIT -"""Single-fused Triton grouped-top-k routing kernel. - -Drop-in replacement for the ``topk(...)`` call inside -``aiter/ops/triton/moe/moe_routing/routing.py::routing_a8w4`` (lines 338-347). -Same return contract — ``(y_vals, y_indx, Bitmatrix)`` — so downstream -``sort_tokens`` / ``sort_tokens_fused`` consume the output unchanged. - -Algorithm (single kernel launch, mirrors the structure of aiter's ``_topk`` -and ``_hash_routing`` in ``_triton_kernels/moe/moe_routing/topk.py``): - - 1. Memset bitmatrix scratchpad / partials (same lane-borrowing trick as - ``_topk``: the first ``s_blocks + sp_blocks`` programs do nothing but - zero-fill). - 2. Load the row of router logits. - 3. Apply ``score_mode`` per element ('softmax' / 'sigmoid' / 'sqrtsoftplus' / - 'none'). - 4. Per-group score reduction over an *arbitrary* expert→group mapping - (``ExpertGroup`` int32 table): - - HAS_BIAS → top-2 sum on bias-augmented scores (DeepSeek-V3 rule; - mirrors ``biased_grouped_topk_torch``). - - else → per-group max (DeepSeek-V2 rule; mirrors - ``grouped_topk_torch``). - 5. Pick top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP is - small, so the unrolled loop is tiny). - 6. Mask experts in non-selected groups to ``-inf`` on the bias-augmented - scores, then do per-expert top-``N_EXPTS_ACT`` via repeated argmax. - 7. Gather *unbiased* weights at the selected indices (matches the - ``noaux_tc`` semantics — bias used for selection only, weights from the - untouched score). - 8. Optional renorm + ``routed_scaling_factor`` scale. - 9. Pack selected indices into the (n_cols_words, n_rows_pad32).T uint32 - bitmatrix layout the kernel emits, identical to ``_topk``. - -Constraints (DeepSeek-class envelope): - - n_expts_tot ≤ 256 (single ``BLOCK_N`` pass; no streaming loop). - - num_expert_group ≤ 16. 
- - topk_group ≤ num_expert_group. - - n_expts_act (top_k) ≤ 16. - - BLOCK_M = 1 (the per-group 3-D intermediate is - ``[BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP]`` fp32 — at BLOCK_M=1 that's - ≤ 256 * 16 * 4 = 16 KiB, fits in registers / LDS comfortably). -""" from __future__ import annotations import torch @@ -276,176 +233,3 @@ def _grouped_topk( ) tl.store(BitsPtrs, r, mask=mask_m[:, None]) - -# --------------------------------------------------------------------------- -# Python wrapper — drop-in for the topk(...) call at routing.py:338-347. -# --------------------------------------------------------------------------- - - -def grouped_topk( - x: torch.Tensor, - k: int, - num_expert_group: int, - topk_group: int, - *, - expert_group: torch.Tensor | None = None, - apply_softmax: bool = False, # accepted for parity with topk(); ignored - HIST_BLOCK_M: int = 32, - score_mode: str = "softmax", - bias: torch.Tensor | None = None, - renorm: bool = False, - routed_scaling_factor: float = 1.0, - num_fused_shared_experts: int = 0, - shared_experts_score: float = 1.0, -): - """Triton grouped top-k expert selection. See module docstring. - - Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of - ``aiter.ops.triton.moe.moe_routing.topk.topk``: - - - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. - - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. - - When ``num_fused_shared_experts > 0`` the routed top-k selection occupies - the first ``k`` columns and the always-on shared expert(s) occupy the next - ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight - ``shared_experts_score`` (appended after the routed renorm, mirroring - ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). The bitmatrix - is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` - counts the shared bucket. 
- - - bitmatrix: real :class:`Bitmatrix`; same uint32 - ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the - ``_topk`` kernel emits, so ``sort_tokens`` and ``sort_tokens_fused`` - consume it unchanged. - """ - assert x.dim() == 2 - n_rows, n_cols = x.shape - assert n_cols <= 256, ( - f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" - ) - # Fused shared experts are appended (always-on) AFTER the routed selection; - # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). - n_shared = num_fused_shared_experts - assert n_shared >= 0 - n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) - k_out = k + n_shared # output width (routed top-k + shared) - assert num_expert_group > 1 - assert num_expert_group <= 16, ( - f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" - ) - assert 0 < topk_group <= num_expert_group - assert 0 < k <= 16 - assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( - f"unknown score_mode {score_mode!r}" - ) - has_bias = bias is not None - if has_bias: - assert bias.dim() == 1 and bias.shape[0] == n_cols - assert bias.dtype == torch.float32 - assert score_mode in ("sqrtsoftplus", "sigmoid"), ( - "bias only supported with sqrtsoftplus / sigmoid" - ) - - dev = x.device - - # Default expert→group mapping = contiguous DeepSeek layout. - if expert_group is None: - assert n_cols % num_expert_group == 0, ( - f"n_expts_tot ({n_cols}) not divisible by num_expert_group " - f"({num_expert_group}); pass an explicit expert_group table." - ) - g_size = n_cols // num_expert_group - expert_group = ( - torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size - ).to(torch.int32) - else: - assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols - assert expert_group.dtype == torch.int32 - - # Block sizes — single BLOCK_N pass for DeepSeek envelope. BLOCK_N must - # cover the shared-expert columns too so their bits fit in the bitmatrix. 
- BLOCK_M = 1 - BLOCK_N = max(32, triton.next_power_of_2(n_total)) - N_EXPTS_PAD = BLOCK_N - # Mirror topk(): pad to ≥ 2 to dodge tl.argmax/topk(k=1) compile quirks. - N_EXPTS_ACT_PAD = max(2, triton.next_power_of_2(k_out)) - BLOCK_S = 128 - BLOCK_SP = 128 - TILE_SIZE = 8 - - # Outputs (same shapes / dtypes as topk(...)), widened by the shared slots. - y_vals = torch.empty((n_rows, k_out), dtype=x.dtype, device=dev) - y_indx = torch.empty((n_rows, k_out), dtype=torch.int16, device=dev) - - # Bitmatrix in transposed-uint32 storage layout (identical to topk()). - n_cols_pad = triton.cdiv(n_total, BLOCK_N) * BLOCK_N - n_cols_words = n_cols_pad // 32 - bitmatrix_data = torch.empty( - (n_cols_words, triton.cdiv(n_rows, 32) * 32), - dtype=torch.uint32, - device=dev, - ) - bitmatrix_data = torch.transpose(bitmatrix_data, 0, 1)[:n_rows] - - # Scratchpads. The per-column sum buffer consumed by Bitmatrix.sum() / - # sort_tokens must cover the full padded column count (n_cols_pad), which - # widens with the shared experts; sizing by n_total alone can under-allocate - # (e.g. n_total=257 -> n_cols_pad=512 but cdiv(257,128)*128=384). 
- s_blocks = triton.cdiv(n_cols_pad, BLOCK_S) - s_cols = s_blocks * BLOCK_S - scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) - BLOCK_MM = HIST_BLOCK_M * TILE_SIZE - pids_x = triton.cdiv(n_rows, BLOCK_MM) - scratchpad_partials = torch.empty( - (n_cols_pad, pids_x * TILE_SIZE), dtype=torch.int32, device=dev - ) - scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) - sp_size = scratchpad_partials.numel() - sp_blocks = triton.cdiv(sp_size, BLOCK_SP) - - pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) - - _grouped_topk[(pids,)]( - x, - x.stride(0), - expert_group, - y_vals, - y_indx, - y_vals.stride(0), - bitmatrix_data, - bitmatrix_data.stride(0), - bitmatrix_data.stride(1), - n_rows, - n_cols, - scratchpad, - BLOCK_S, - s_blocks, - scratchpad_partials, - BLOCK_SP, - sp_blocks, - sp_size, - BLOCK_M=BLOCK_M, - N_EXPTS_PAD=N_EXPTS_PAD, - BLOCK_N=BLOCK_N, - N_EXPTS_ACT=k, - N_EXPTS_ACT_PAD=N_EXPTS_ACT_PAD, - NUM_EXPERT_GROUP=num_expert_group, - TOPK_GROUP=topk_group, - Bias=bias, - SCORE_MODE=score_mode, - HAS_BIAS=has_bias, - APPLY_RENORM=renorm, - ROUTED_SCALING=routed_scaling_factor, - N_SHARED=n_shared, - SHARED_SCORE=shared_experts_score, - num_warps=4, - ) - - bitmatrix = Bitmatrix( - bitmatrix_data, - shape=[n_rows, n_cols_words * 32], - scratchpad=scratchpad, - scratchpad_partials=scratchpad_partials, - ) - return y_vals, y_indx, bitmatrix diff --git a/aiter/ops/triton/moe/moe_routing/grouped_topk.py b/aiter/ops/triton/moe/moe_routing/grouped_topk.py new file mode 100644 index 0000000000..2c72039ad7 --- /dev/null +++ b/aiter/ops/triton/moe/moe_routing/grouped_topk.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix +from aiter.ops.triton._triton_kernels.moe.moe_routing.grouped_topk import _grouped_topk + +# SPDX-License-Identifier: MIT +"""Single-fused Triton grouped-top-k routing 
kernel. + +Drop-in replacement for the ``topk(...)`` call inside +``aiter/ops/triton/moe/moe_routing/routing.py::routing_a8w4`` (lines 338-347). +Same return contract — ``(y_vals, y_indx, Bitmatrix)`` — so downstream +``sort_tokens`` / ``sort_tokens_fused`` consume the output unchanged. + +Algorithm (single kernel launch, mirrors the structure of aiter's ``_topk`` +and ``_hash_routing`` in ``_triton_kernels/moe/moe_routing/topk.py``): + + 1. Memset bitmatrix scratchpad / partials (same lane-borrowing trick as + ``_topk``: the first ``s_blocks + sp_blocks`` programs do nothing but + zero-fill). + 2. Load the row of router logits. + 3. Apply ``score_mode`` per element ('softmax' / 'sigmoid' / 'sqrtsoftplus' / + 'none'). + 4. Per-group score reduction over an *arbitrary* expert→group mapping + (``ExpertGroup`` int32 table): + - HAS_BIAS → top-2 sum on bias-augmented scores (DeepSeek-V3 rule; + mirrors ``biased_grouped_topk_torch``). + - else → per-group max (DeepSeek-V2 rule; mirrors + ``grouped_topk_torch``). + 5. Pick top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP is + small, so the unrolled loop is tiny). + 6. Mask experts in non-selected groups to ``-inf`` on the bias-augmented + scores, then do per-expert top-``N_EXPTS_ACT`` via repeated argmax. + 7. Gather *unbiased* weights at the selected indices (matches the + ``noaux_tc`` semantics — bias used for selection only, weights from the + untouched score). + 8. Optional renorm + ``routed_scaling_factor`` scale. + 9. Pack selected indices into the (n_cols_words, n_rows_pad32).T uint32 + bitmatrix layout the kernel emits, identical to ``_topk``. + +Constraints (DeepSeek-class envelope): + - n_expts_tot ≤ 256 (single ``BLOCK_N`` pass; no streaming loop). + - num_expert_group ≤ 16. + - topk_group ≤ num_expert_group. + - n_expts_act (top_k) ≤ 16. 
+ - BLOCK_M = 1 (the per-group 3-D intermediate is + ``[BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP]`` fp32 — at BLOCK_M=1 that's + ≤ 256 * 16 * 4 = 16 KiB, fits in registers / LDS comfortably). +""" + +def grouped_topk( + x: torch.Tensor, + k: int, + num_expert_group: int, + topk_group: int, + *, + expert_group: torch.Tensor | None = None, + apply_softmax: bool = False, # accepted for parity with topk(); ignored + HIST_BLOCK_M: int = 32, + score_mode: str = "softmax", + bias: torch.Tensor | None = None, + renorm: bool = False, + routed_scaling_factor: float = 1.0, + num_fused_shared_experts: int = 0, + shared_experts_score: float = 1.0, +): + """Triton grouped top-k expert selection. See module docstring. + + Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of + ``aiter.ops.triton.moe.moe_routing.topk.topk``: + + - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. + - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. + + When ``num_fused_shared_experts > 0`` the routed top-k selection occupies + the first ``k`` columns and the always-on shared expert(s) occupy the next + ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight + ``shared_experts_score`` (appended after the routed renorm, mirroring + ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). The bitmatrix + is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` + counts the shared bucket. + + - bitmatrix: real :class:`Bitmatrix`; same uint32 + ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the + ``_topk`` kernel emits, so ``sort_tokens`` and ``sort_tokens_fused`` + consume it unchanged. + """ + assert x.dim() == 2 + n_rows, n_cols = x.shape + assert n_cols <= 256, ( + f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" + ) + # Fused shared experts are appended (always-on) AFTER the routed selection; + # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). 
+ n_shared = num_fused_shared_experts + assert n_shared >= 0 + n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) + k_out = k + n_shared # output width (routed top-k + shared) + assert num_expert_group > 1 + assert num_expert_group <= 16, ( + f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" + ) + assert 0 < topk_group <= num_expert_group + assert 0 < k <= 16 + assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( + f"unknown score_mode {score_mode!r}" + ) + has_bias = bias is not None + if has_bias: + assert bias.dim() == 1 and bias.shape[0] == n_cols + assert bias.dtype == torch.float32 + assert score_mode in ("sqrtsoftplus", "sigmoid"), ( + "bias only supported with sqrtsoftplus / sigmoid" + ) + + dev = x.device + + # Default expert→group mapping = contiguous DeepSeek layout. + if expert_group is None: + assert n_cols % num_expert_group == 0, ( + f"n_expts_tot ({n_cols}) not divisible by num_expert_group " + f"({num_expert_group}); pass an explicit expert_group table." + ) + g_size = n_cols // num_expert_group + expert_group = ( + torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size + ).to(torch.int32) + else: + assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols + assert expert_group.dtype == torch.int32 + + # Block sizes — single BLOCK_N pass for DeepSeek envelope. BLOCK_N must + # cover the shared-expert columns too so their bits fit in the bitmatrix. + BLOCK_M = 1 + BLOCK_N = max(32, triton.next_power_of_2(n_total)) + N_EXPTS_PAD = BLOCK_N + # Mirror topk(): pad to ≥ 2 to dodge tl.argmax/topk(k=1) compile quirks. + N_EXPTS_ACT_PAD = max(2, triton.next_power_of_2(k_out)) + BLOCK_S = 128 + BLOCK_SP = 128 + TILE_SIZE = 8 + + # Outputs (same shapes / dtypes as topk(...)), widened by the shared slots. 
+ y_vals = torch.empty((n_rows, k_out), dtype=x.dtype, device=dev) + y_indx = torch.empty((n_rows, k_out), dtype=torch.int16, device=dev) + + # Bitmatrix in transposed-uint32 storage layout (identical to topk()). + n_cols_pad = triton.cdiv(n_total, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + bitmatrix_data = torch.empty( + (n_cols_words, triton.cdiv(n_rows, 32) * 32), + dtype=torch.uint32, + device=dev, + ) + bitmatrix_data = torch.transpose(bitmatrix_data, 0, 1)[:n_rows] + + # Scratchpads. The per-column sum buffer consumed by Bitmatrix.sum() / + # sort_tokens must cover the full padded column count (n_cols_pad), which + # widens with the shared experts; sizing by n_total alone can under-allocate + # (e.g. n_total=257 -> n_cols_pad=512 but cdiv(257,128)*128=384). + s_blocks = triton.cdiv(n_cols_pad, BLOCK_S) + s_cols = s_blocks * BLOCK_S + scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) + BLOCK_MM = HIST_BLOCK_M * TILE_SIZE + pids_x = triton.cdiv(n_rows, BLOCK_MM) + scratchpad_partials = torch.empty( + (n_cols_pad, pids_x * TILE_SIZE), dtype=torch.int32, device=dev + ) + scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) + sp_size = scratchpad_partials.numel() + sp_blocks = triton.cdiv(sp_size, BLOCK_SP) + + pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) + + _grouped_topk[(pids,)]( + x, + x.stride(0), + expert_group, + y_vals, + y_indx, + y_vals.stride(0), + bitmatrix_data, + bitmatrix_data.stride(0), + bitmatrix_data.stride(1), + n_rows, + n_cols, + scratchpad, + BLOCK_S, + s_blocks, + scratchpad_partials, + BLOCK_SP, + sp_blocks, + sp_size, + BLOCK_M=BLOCK_M, + N_EXPTS_PAD=N_EXPTS_PAD, + BLOCK_N=BLOCK_N, + N_EXPTS_ACT=k, + N_EXPTS_ACT_PAD=N_EXPTS_ACT_PAD, + NUM_EXPERT_GROUP=num_expert_group, + TOPK_GROUP=topk_group, + Bias=bias, + SCORE_MODE=score_mode, + HAS_BIAS=has_bias, + APPLY_RENORM=renorm, + ROUTED_SCALING=routed_scaling_factor, + N_SHARED=n_shared, + SHARED_SCORE=shared_experts_score, + 
num_warps=4, + ) + + bitmatrix = Bitmatrix( + bitmatrix_data, + shape=[n_rows, n_cols_words * 32], + scratchpad=scratchpad, + scratchpad_partials=scratchpad_partials, + ) + return y_vals, y_indx, bitmatrix diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index 1ff76df03c..d8f2fef926 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -338,7 +338,7 @@ def routing_a8w4( When ``use_grouped_topk=True``, step 1 is replaced by ATOM's single-fused Triton ``grouped_topk`` kernel - (``atom.model_ops.grouped_topk_triton.grouped_topk``) — DeepSeek-V2/V3-style + (``atom.model_ops.grouped_topk.grouped_topk``) — DeepSeek-V2/V3-style hierarchical routing (pick ``topk_group`` groups out of ``num_expert_group``, then top-``n_expts_act`` experts within those groups). Same return contract as ``topk`` (y_vals, y_indx, Bitmatrix), so @@ -367,7 +367,7 @@ def routing_a8w4( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" # Lazy import: ATOM-side kernel; avoids hard aiter→atom import order. 
- from aiter.ops.triton.moe.moe_routing.grouped_topk_triton import grouped_topk + from aiter.ops.triton.moe.moe_routing.grouped_topk import grouped_topk expt_scal, expt_indx, bitmatrix = grouped_topk( logits, From 980ef5b4836b9feade96b143e3692ab826dd6986 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Tue, 2 Jun 2026 18:28:08 +0000 Subject: [PATCH 04/18] moved grouped_topk to regular topk files --- .../moe/moe_routing/grouped_topk.py | 235 ------------------ .../_triton_kernels/moe/moe_routing/topk.py | 229 +++++++++++++++++ .../triton/moe/moe_routing/grouped_topk.py | 220 ---------------- aiter/ops/triton/moe/moe_routing/routing.py | 2 +- aiter/ops/triton/moe/moe_routing/topk.py | 172 ++++++++++++- 5 files changed, 401 insertions(+), 457 deletions(-) delete mode 100644 aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py delete mode 100644 aiter/ops/triton/moe/moe_routing/grouped_topk.py diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py deleted file mode 100644 index daa0c1e3dc..0000000000 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/grouped_topk.py +++ /dev/null @@ -1,235 +0,0 @@ -from __future__ import annotations - -import torch -import triton -import triton.language as tl - -from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix - - -@triton.jit -def _grouped_topk( - X, # router logits [n_rows, n_expts_tot] (bf16/fp32) - stride_xm, - ExpertGroup, # int32 [n_expts_tot] expert→group_id - Yv, # [n_rows, N_EXPTS_ACT_PAD] selected weights - Yi, # [n_rows, N_EXPTS_ACT_PAD] selected expert ids (int16) - stride_ym, - Bits, # bitmatrix data - stride_rm, - stride_rn, - n_rows, - n_expts_tot, - S, # bitmatrix scratchpad — must memset to 0 - BLOCK_S: tl.constexpr, - s_blocks, - SP, # bitmatrix partials — must memset to 0 - BLOCK_SP: tl.constexpr, - sp_blocks, - sp_size, - BLOCK_M: tl.constexpr, - N_EXPTS_PAD: tl.constexpr, - BLOCK_N: 
tl.constexpr, - N_EXPTS_ACT: tl.constexpr, - N_EXPTS_ACT_PAD: tl.constexpr, - NUM_EXPERT_GROUP: tl.constexpr, - TOPK_GROUP: tl.constexpr, - Bias=None, - SCORE_MODE: tl.constexpr = "softmax", - HAS_BIAS: tl.constexpr = False, - APPLY_RENORM: tl.constexpr = False, - ROUTED_SCALING: tl.constexpr = 1.0, - N_SHARED: tl.constexpr = 0, - SHARED_SCORE: tl.constexpr = 1.0, -): - pid = tl.program_id(0) - - # -- Memset bitmatrix scratchpads (same idiom as _topk / _hash_routing). - if pid < s_blocks: - tl.store( - S + BLOCK_S * pid + tl.arange(0, BLOCK_S), - tl.zeros([BLOCK_S], tl.int32), - ) - elif pid < s_blocks + sp_blocks: - offs = BLOCK_SP * (pid - s_blocks) + tl.arange(0, BLOCK_SP) - tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size) - - if pid * BLOCK_M >= n_rows: - return - - tl.static_assert(BLOCK_N % 32 == 0) - tl.static_assert( - N_EXPTS_PAD == BLOCK_N, - "DeepSeek-class envelope: BLOCK_N must equal N_EXPTS_PAD (single-block).", - ) - - x_dtype: tl.constexpr = X.dtype.element_ty - - offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) - mask_m = offs_m < n_rows - offs_n = tl.arange(0, BLOCK_N) - mask_n = offs_n < n_expts_tot - - # -- 1. Load logits. - X_ptrs = X + offs_m[:, None] * stride_xm + offs_n[None, :] - x = tl.load(X_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0) - - # -- 2. Score transform. - if SCORE_MODE == "softmax": - # Numerically-stable row softmax with masked-out lanes set to -inf. 
- x_f = tl.where(mask_n[None, :], x.to(tl.float32), float("-inf")) - x_max = tl.max(x_f, axis=1, keep_dims=True) - x_e = tl.exp(x_f - x_max) - x_e = tl.where(mask_n[None, :], x_e, 0.0) - scores = x_e / (tl.sum(x_e, axis=1, keep_dims=True) + 1e-30) - elif SCORE_MODE == "sigmoid": - scores = 1.0 / (1.0 + tl.exp(-x.to(tl.float32))) - elif SCORE_MODE == "sqrtsoftplus": - x_f = x.to(tl.float32) - sp = tl.maximum(x_f, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(x_f))) - scores = tl.sqrt(sp) - else: - scores = x.to(tl.float32) - - # Pad-lane safety: invalid columns must lose every comparison. - scores = tl.where(mask_n[None, :], scores, float("-inf")) - - # -- 3. Bias-augmented choice scores. Weights are gathered later from the - # untouched ``scores`` (matches biased_grouped_topk_torch + - # FusedMoE.select_experts sigmoid path: select on s+b, return s). - if HAS_BIAS: - b = tl.load(Bias + offs_n, mask=mask_n, other=0.0).to(tl.float32) - scores_for_choice = scores + b[None, :] - else: - scores_for_choice = scores - - # -- 4. Per-group reduction over arbitrary expert→group mapping. - gid = tl.load(ExpertGroup + offs_n, mask=mask_n, other=0).to(tl.int32) - g_arange = tl.arange(0, NUM_EXPERT_GROUP) - gid_eq = gid[:, None] == g_arange[None, :] # [BLOCK_N, NUM_EXPERT_GROUP] - - # 3-D one-hot expand: [BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP], with -inf - # outside each group's column. - sfc_3d = scores_for_choice[:, :, None].broadcast_to( - BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP - ) - expanded = tl.where(gid_eq[None, :, :], sfc_3d, float("-inf")) - group_max1 = tl.max(expanded, axis=1) # [BLOCK_M, NUM_EXPERT_GROUP] - - if HAS_BIAS: - # Top-2-sum-per-group. To find the second-largest score per group - # without tl.argmax-on-3D, suppress the per-group max by exact-equality - # match (ties on float scores are negligible in DeepSeek workloads). 
- gm1_per_e = tl.sum( - gid_eq[None, :, :].to(tl.float32) * group_max1[:, None, :], - axis=2, - ) # [BLOCK_M, BLOCK_N] - suppressed = tl.where( - scores_for_choice == gm1_per_e, float("-inf"), scores_for_choice - ) - sup_3d = suppressed[:, :, None].broadcast_to( - BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP - ) - expanded2 = tl.where(gid_eq[None, :, :], sup_3d, float("-inf")) - group_max2 = tl.max(expanded2, axis=1) - group_scores = group_max1 + group_max2 - else: - group_scores = group_max1 - - # -- 5. Top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP - # is small; static-range unroll). - group_mask_i = tl.zeros([BLOCK_M, NUM_EXPERT_GROUP], dtype=tl.int32) - gs = group_scores - for _gj in tl.static_range(TOPK_GROUP): - am_g = tl.argmax(gs, axis=1).to(tl.int32) # [BLOCK_M] - sel_g = (g_arange[None, :] == am_g[:, None]) # [BLOCK_M, NUM_EXPERT_GROUP] - group_mask_i = group_mask_i | sel_g.to(tl.int32) - gs = tl.where(sel_g, float("-inf"), gs) - - # -- 6. Per-(token, expert) keep-mask via group-id lookup, then suppress - # experts in non-selected groups on the bias-augmented scores. - expert_keep = tl.sum( - gid_eq[None, :, :].to(tl.int32) * group_mask_i[:, None, :], - axis=2, - ) > 0 # [BLOCK_M, BLOCK_N] - sfc_masked = tl.where(expert_keep, scores_for_choice, float("-inf")) - - # -- 7. Per-expert top-``N_EXPTS_ACT`` via repeated argmax. Padded slots - # (N_EXPTS_ACT_PAD > N_EXPTS_ACT) are kept in the y_indices/y_values - # buffers but masked off on the writeback / bitmatrix-pack. - n_arange = tl.arange(0, BLOCK_N) - y_indices = tl.zeros([BLOCK_M, N_EXPTS_ACT_PAD], dtype=tl.int32) - sfc_iter = sfc_masked - for kj in tl.static_range(N_EXPTS_ACT): - am_k = tl.argmax(sfc_iter, axis=1).to(tl.int32) # [BLOCK_M] - slot_eq = (tl.arange(0, N_EXPTS_ACT_PAD) == kj)[None, :] - y_indices = tl.where(slot_eq, am_k[:, None], y_indices) - sfc_iter = tl.where( - n_arange[None, :] == am_k[:, None], float("-inf"), sfc_iter - ) - - # -- 8. Gather UNBIASED weights at selected indices. 
- pos_eq = ( - n_arange[None, None, :] == y_indices[:, :, None] - ) # [BLOCK_M, K_PAD, BLOCK_N] - scores_3d = scores[:, None, :].broadcast_to(BLOCK_M, N_EXPTS_ACT_PAD, BLOCK_N) - y_weights = tl.sum(tl.where(pos_eq, scores_3d, 0.0), axis=2) # [BLOCK_M, K_PAD] - - # Routed-slot mask: the first N_EXPTS_ACT slots hold the grouped-topk - # selection (shared experts, if any, occupy the next N_SHARED slots and - # must be excluded from the routed renorm denominator). - k_arange = tl.arange(0, N_EXPTS_ACT_PAD) - routed_mask = k_arange[None, :] < N_EXPTS_ACT - - # -- 9. Renorm + scale over the ROUTED slots only (mirrors _topk's - # APPLY_RENORM / ROUTED_SCALING and the noaux_tc semantics where the - # always-on shared expert is appended unscaled after renorm). - if APPLY_RENORM: - y_f = tl.where(routed_mask, y_weights, 0.0) - s = tl.sum(y_f, axis=1, keep_dims=True) - y_weights = y_f / (s + 1e-20) * ROUTED_SCALING - elif ROUTED_SCALING != 1.0: - y_weights = y_weights * ROUTED_SCALING - - # -- 9b. Append fused shared expert(s): always-on, fixed id n_expts_tot+i - # and fixed weight SHARED_SCORE (matches init_aiter_topK_meta_data / - # rocm_aiter_grouped_topk). Placed AFTER renorm so the shared weight - # is not folded into the routed normalization. - if N_SHARED > 0: - shared_slot = (k_arange[None, :] >= N_EXPTS_ACT) & ( - k_arange[None, :] < N_EXPTS_ACT + N_SHARED - ) - shared_idx = (n_expts_tot + k_arange - N_EXPTS_ACT)[None, :].to(tl.int32) - y_indices = tl.where(shared_slot, shared_idx, y_indices) - y_weights = tl.where(shared_slot, SHARED_SCORE, y_weights) - real_mask = k_arange[None, :] < (N_EXPTS_ACT + N_SHARED) - else: - real_mask = routed_mask - - y_values_out = y_weights.to(x_dtype) - - # -- 10. Writeback selected weights / indices. 
- Yv_ptrs = Yv + offs_m[:, None] * stride_ym + k_arange[None, :] - Yi_ptrs = Yi + offs_m[:, None] * stride_ym + k_arange[None, :] - write_mask = mask_m[:, None] & real_mask - tl.store(Yv_ptrs, y_values_out, mask=write_mask) - tl.store(Yi_ptrs, y_indices, mask=write_mask) - - # -- 11. Pack into bitmatrix (mirrors _topk's tail). - safe_idx = tl.where(real_mask, y_indices, 0).to(tl.uint32) - y_div = safe_idx // 32 - y_rem = safe_idx % 32 - bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N # = 1 (single-block) - for i in range(bm_iters): - offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) - y2 = tl.where( - (y_div[:, :, None] == offs_r_n[None, None, :]) & real_mask[:, :, None], - (1 << y_rem)[:, :, None], - 0, - ) - r = tl.reduce_or(y2, axis=1) - BitsPtrs = ( - Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn - ) - tl.store(BitsPtrs, r, mask=mask_m[:, None]) - diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py index 539380f240..86abb99b52 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py @@ -426,3 +426,232 @@ def _hash_routing( r = tl.reduce_or(y2, axis=1) BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn tl.store(BitsPtrs, r, mask=mask_m[:, None]) + + + +@triton.jit +def _grouped_topk( + X, # router logits [n_rows, n_expts_tot] (bf16/fp32) + stride_xm, + ExpertGroup, # int32 [n_expts_tot] expert→group_id + Yv, # [n_rows, N_EXPTS_ACT_PAD] selected weights + Yi, # [n_rows, N_EXPTS_ACT_PAD] selected expert ids (int16) + stride_ym, + Bits, # bitmatrix data + stride_rm, + stride_rn, + n_rows, + n_expts_tot, + S, # bitmatrix scratchpad — must memset to 0 + BLOCK_S: tl.constexpr, + s_blocks, + SP, # bitmatrix partials — must memset to 0 + BLOCK_SP: tl.constexpr, + sp_blocks, + sp_size, + BLOCK_M: tl.constexpr, + N_EXPTS_PAD: tl.constexpr, + BLOCK_N: tl.constexpr, 
+ N_EXPTS_ACT: tl.constexpr, + N_EXPTS_ACT_PAD: tl.constexpr, + NUM_EXPERT_GROUP: tl.constexpr, + TOPK_GROUP: tl.constexpr, + Bias=None, + SCORE_MODE: tl.constexpr = "softmax", + HAS_BIAS: tl.constexpr = False, + APPLY_RENORM: tl.constexpr = False, + ROUTED_SCALING: tl.constexpr = 1.0, + N_SHARED: tl.constexpr = 0, + SHARED_SCORE: tl.constexpr = 1.0, +): + pid = tl.program_id(0) + + # -- Memset bitmatrix scratchpads (same idiom as _topk / _hash_routing). + if pid < s_blocks: + tl.store( + S + BLOCK_S * pid + tl.arange(0, BLOCK_S), + tl.zeros([BLOCK_S], tl.int32), + ) + elif pid < s_blocks + sp_blocks: + offs = BLOCK_SP * (pid - s_blocks) + tl.arange(0, BLOCK_SP) + tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size) + + if pid * BLOCK_M >= n_rows: + return + + tl.static_assert(BLOCK_N % 32 == 0) + tl.static_assert( + N_EXPTS_PAD == BLOCK_N, + "DeepSeek-class envelope: BLOCK_N must equal N_EXPTS_PAD (single-block).", + ) + + x_dtype: tl.constexpr = X.dtype.element_ty + + offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m < n_rows + offs_n = tl.arange(0, BLOCK_N) + mask_n = offs_n < n_expts_tot + + # -- 1. Load logits. + X_ptrs = X + offs_m[:, None] * stride_xm + offs_n[None, :] + x = tl.load(X_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0) + + # -- 2. Score transform. + if SCORE_MODE == "softmax": + # Numerically-stable row softmax with masked-out lanes set to -inf. 
+ x_f = tl.where(mask_n[None, :], x.to(tl.float32), float("-inf")) + x_max = tl.max(x_f, axis=1, keep_dims=True) + x_e = tl.exp(x_f - x_max) + x_e = tl.where(mask_n[None, :], x_e, 0.0) + scores = x_e / (tl.sum(x_e, axis=1, keep_dims=True) + 1e-30) + elif SCORE_MODE == "sigmoid": + scores = 1.0 / (1.0 + tl.exp(-x.to(tl.float32))) + elif SCORE_MODE == "sqrtsoftplus": + x_f = x.to(tl.float32) + sp = tl.maximum(x_f, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(x_f))) + scores = tl.sqrt(sp) + else: + scores = x.to(tl.float32) + + # Pad-lane safety: invalid columns must lose every comparison. + scores = tl.where(mask_n[None, :], scores, float("-inf")) + + # -- 3. Bias-augmented choice scores. Weights are gathered later from the + # untouched ``scores`` (matches biased_grouped_topk_torch + + # FusedMoE.select_experts sigmoid path: select on s+b, return s). + if HAS_BIAS: + b = tl.load(Bias + offs_n, mask=mask_n, other=0.0).to(tl.float32) + scores_for_choice = scores + b[None, :] + else: + scores_for_choice = scores + + # -- 4. Per-group reduction over arbitrary expert→group mapping. + gid = tl.load(ExpertGroup + offs_n, mask=mask_n, other=0).to(tl.int32) + g_arange = tl.arange(0, NUM_EXPERT_GROUP) + gid_eq = gid[:, None] == g_arange[None, :] # [BLOCK_N, NUM_EXPERT_GROUP] + + # 3-D one-hot expand: [BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP], with -inf + # outside each group's column. + sfc_3d = scores_for_choice[:, :, None].broadcast_to( + BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP + ) + expanded = tl.where(gid_eq[None, :, :], sfc_3d, float("-inf")) + group_max1 = tl.max(expanded, axis=1) # [BLOCK_M, NUM_EXPERT_GROUP] + + if HAS_BIAS: + # Top-2-sum-per-group. To find the second-largest score per group + # without tl.argmax-on-3D, suppress the per-group max by exact-equality + # match (ties on float scores are negligible in DeepSeek workloads). 
+ gm1_per_e = tl.sum( + gid_eq[None, :, :].to(tl.float32) * group_max1[:, None, :], + axis=2, + ) # [BLOCK_M, BLOCK_N] + suppressed = tl.where( + scores_for_choice == gm1_per_e, float("-inf"), scores_for_choice + ) + sup_3d = suppressed[:, :, None].broadcast_to( + BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP + ) + expanded2 = tl.where(gid_eq[None, :, :], sup_3d, float("-inf")) + group_max2 = tl.max(expanded2, axis=1) + group_scores = group_max1 + group_max2 + else: + group_scores = group_max1 + + # -- 5. Top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP + # is small; static-range unroll). + group_mask_i = tl.zeros([BLOCK_M, NUM_EXPERT_GROUP], dtype=tl.int32) + gs = group_scores + for _gj in tl.static_range(TOPK_GROUP): + am_g = tl.argmax(gs, axis=1).to(tl.int32) # [BLOCK_M] + sel_g = (g_arange[None, :] == am_g[:, None]) # [BLOCK_M, NUM_EXPERT_GROUP] + group_mask_i = group_mask_i | sel_g.to(tl.int32) + gs = tl.where(sel_g, float("-inf"), gs) + + # -- 6. Per-(token, expert) keep-mask via group-id lookup, then suppress + # experts in non-selected groups on the bias-augmented scores. + expert_keep = tl.sum( + gid_eq[None, :, :].to(tl.int32) * group_mask_i[:, None, :], + axis=2, + ) > 0 # [BLOCK_M, BLOCK_N] + sfc_masked = tl.where(expert_keep, scores_for_choice, float("-inf")) + + # -- 7. Per-expert top-``N_EXPTS_ACT`` via repeated argmax. Padded slots + # (N_EXPTS_ACT_PAD > N_EXPTS_ACT) are kept in the y_indices/y_values + # buffers but masked off on the writeback / bitmatrix-pack. + n_arange = tl.arange(0, BLOCK_N) + y_indices = tl.zeros([BLOCK_M, N_EXPTS_ACT_PAD], dtype=tl.int32) + sfc_iter = sfc_masked + for kj in tl.static_range(N_EXPTS_ACT): + am_k = tl.argmax(sfc_iter, axis=1).to(tl.int32) # [BLOCK_M] + slot_eq = (tl.arange(0, N_EXPTS_ACT_PAD) == kj)[None, :] + y_indices = tl.where(slot_eq, am_k[:, None], y_indices) + sfc_iter = tl.where( + n_arange[None, :] == am_k[:, None], float("-inf"), sfc_iter + ) + + # -- 8. Gather UNBIASED weights at selected indices. 
+ pos_eq = ( + n_arange[None, None, :] == y_indices[:, :, None] + ) # [BLOCK_M, K_PAD, BLOCK_N] + scores_3d = scores[:, None, :].broadcast_to(BLOCK_M, N_EXPTS_ACT_PAD, BLOCK_N) + y_weights = tl.sum(tl.where(pos_eq, scores_3d, 0.0), axis=2) # [BLOCK_M, K_PAD] + + # Routed-slot mask: the first N_EXPTS_ACT slots hold the grouped-topk + # selection (shared experts, if any, occupy the next N_SHARED slots and + # must be excluded from the routed renorm denominator). + k_arange = tl.arange(0, N_EXPTS_ACT_PAD) + routed_mask = k_arange[None, :] < N_EXPTS_ACT + + # -- 9. Renorm + scale over the ROUTED slots only (mirrors _topk's + # APPLY_RENORM / ROUTED_SCALING and the noaux_tc semantics where the + # always-on shared expert is appended unscaled after renorm). + if APPLY_RENORM: + y_f = tl.where(routed_mask, y_weights, 0.0) + s = tl.sum(y_f, axis=1, keep_dims=True) + y_weights = y_f / (s + 1e-20) * ROUTED_SCALING + elif ROUTED_SCALING != 1.0: + y_weights = y_weights * ROUTED_SCALING + + # -- 9b. Append fused shared expert(s): always-on, fixed id n_expts_tot+i + # and fixed weight SHARED_SCORE (matches init_aiter_topK_meta_data / + # rocm_aiter_grouped_topk). Placed AFTER renorm so the shared weight + # is not folded into the routed normalization. + if N_SHARED > 0: + shared_slot = (k_arange[None, :] >= N_EXPTS_ACT) & ( + k_arange[None, :] < N_EXPTS_ACT + N_SHARED + ) + shared_idx = (n_expts_tot + k_arange - N_EXPTS_ACT)[None, :].to(tl.int32) + y_indices = tl.where(shared_slot, shared_idx, y_indices) + y_weights = tl.where(shared_slot, SHARED_SCORE, y_weights) + real_mask = k_arange[None, :] < (N_EXPTS_ACT + N_SHARED) + else: + real_mask = routed_mask + + y_values_out = y_weights.to(x_dtype) + + # -- 10. Writeback selected weights / indices. 
+ Yv_ptrs = Yv + offs_m[:, None] * stride_ym + k_arange[None, :] + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + k_arange[None, :] + write_mask = mask_m[:, None] & real_mask + tl.store(Yv_ptrs, y_values_out, mask=write_mask) + tl.store(Yi_ptrs, y_indices, mask=write_mask) + + # -- 11. Pack into bitmatrix (mirrors _topk's tail). + safe_idx = tl.where(real_mask, y_indices, 0).to(tl.uint32) + y_div = safe_idx // 32 + y_rem = safe_idx % 32 + bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N # = 1 (single-block) + for i in range(bm_iters): + offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) + y2 = tl.where( + (y_div[:, :, None] == offs_r_n[None, None, :]) & real_mask[:, :, None], + (1 << y_rem)[:, :, None], + 0, + ) + r = tl.reduce_or(y2, axis=1) + BitsPtrs = ( + Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn + ) + tl.store(BitsPtrs, r, mask=mask_m[:, None]) + diff --git a/aiter/ops/triton/moe/moe_routing/grouped_topk.py b/aiter/ops/triton/moe/moe_routing/grouped_topk.py deleted file mode 100644 index 2c72039ad7..0000000000 --- a/aiter/ops/triton/moe/moe_routing/grouped_topk.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import annotations - -import torch -import triton -import triton.language as tl - -from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix -from aiter.ops.triton._triton_kernels.moe.moe_routing.grouped_topk import _grouped_topk - -# SPDX-License-Identifier: MIT -"""Single-fused Triton grouped-top-k routing kernel. - -Drop-in replacement for the ``topk(...)`` call inside -``aiter/ops/triton/moe/moe_routing/routing.py::routing_a8w4`` (lines 338-347). -Same return contract — ``(y_vals, y_indx, Bitmatrix)`` — so downstream -``sort_tokens`` / ``sort_tokens_fused`` consume the output unchanged. - -Algorithm (single kernel launch, mirrors the structure of aiter's ``_topk`` -and ``_hash_routing`` in ``_triton_kernels/moe/moe_routing/topk.py``): - - 1. 
Memset bitmatrix scratchpad / partials (same lane-borrowing trick as - ``_topk``: the first ``s_blocks + sp_blocks`` programs do nothing but - zero-fill). - 2. Load the row of router logits. - 3. Apply ``score_mode`` per element ('softmax' / 'sigmoid' / 'sqrtsoftplus' / - 'none'). - 4. Per-group score reduction over an *arbitrary* expert→group mapping - (``ExpertGroup`` int32 table): - - HAS_BIAS → top-2 sum on bias-augmented scores (DeepSeek-V3 rule; - mirrors ``biased_grouped_topk_torch``). - - else → per-group max (DeepSeek-V2 rule; mirrors - ``grouped_topk_torch``). - 5. Pick top ``TOPK_GROUP`` groups via repeated argmax (NUM_EXPERT_GROUP is - small, so the unrolled loop is tiny). - 6. Mask experts in non-selected groups to ``-inf`` on the bias-augmented - scores, then do per-expert top-``N_EXPTS_ACT`` via repeated argmax. - 7. Gather *unbiased* weights at the selected indices (matches the - ``noaux_tc`` semantics — bias used for selection only, weights from the - untouched score). - 8. Optional renorm + ``routed_scaling_factor`` scale. - 9. Pack selected indices into the (n_cols_words, n_rows_pad32).T uint32 - bitmatrix layout the kernel emits, identical to ``_topk``. - -Constraints (DeepSeek-class envelope): - - n_expts_tot ≤ 256 (single ``BLOCK_N`` pass; no streaming loop). - - num_expert_group ≤ 16. - - topk_group ≤ num_expert_group. - - n_expts_act (top_k) ≤ 16. - - BLOCK_M = 1 (the per-group 3-D intermediate is - ``[BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP]`` fp32 — at BLOCK_M=1 that's - ≤ 256 * 16 * 4 = 16 KiB, fits in registers / LDS comfortably). 
-""" - -def grouped_topk( - x: torch.Tensor, - k: int, - num_expert_group: int, - topk_group: int, - *, - expert_group: torch.Tensor | None = None, - apply_softmax: bool = False, # accepted for parity with topk(); ignored - HIST_BLOCK_M: int = 32, - score_mode: str = "softmax", - bias: torch.Tensor | None = None, - renorm: bool = False, - routed_scaling_factor: float = 1.0, - num_fused_shared_experts: int = 0, - shared_experts_score: float = 1.0, -): - """Triton grouped top-k expert selection. See module docstring. - - Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of - ``aiter.ops.triton.moe.moe_routing.topk.topk``: - - - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. - - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. - - When ``num_fused_shared_experts > 0`` the routed top-k selection occupies - the first ``k`` columns and the always-on shared expert(s) occupy the next - ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight - ``shared_experts_score`` (appended after the routed renorm, mirroring - ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). The bitmatrix - is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` - counts the shared bucket. - - - bitmatrix: real :class:`Bitmatrix`; same uint32 - ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the - ``_topk`` kernel emits, so ``sort_tokens`` and ``sort_tokens_fused`` - consume it unchanged. - """ - assert x.dim() == 2 - n_rows, n_cols = x.shape - assert n_cols <= 256, ( - f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" - ) - # Fused shared experts are appended (always-on) AFTER the routed selection; - # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). - n_shared = num_fused_shared_experts - assert n_shared >= 0 - n_total = n_cols + n_shared # experts incl. 
shared (bitmatrix width) - k_out = k + n_shared # output width (routed top-k + shared) - assert num_expert_group > 1 - assert num_expert_group <= 16, ( - f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" - ) - assert 0 < topk_group <= num_expert_group - assert 0 < k <= 16 - assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( - f"unknown score_mode {score_mode!r}" - ) - has_bias = bias is not None - if has_bias: - assert bias.dim() == 1 and bias.shape[0] == n_cols - assert bias.dtype == torch.float32 - assert score_mode in ("sqrtsoftplus", "sigmoid"), ( - "bias only supported with sqrtsoftplus / sigmoid" - ) - - dev = x.device - - # Default expert→group mapping = contiguous DeepSeek layout. - if expert_group is None: - assert n_cols % num_expert_group == 0, ( - f"n_expts_tot ({n_cols}) not divisible by num_expert_group " - f"({num_expert_group}); pass an explicit expert_group table." - ) - g_size = n_cols // num_expert_group - expert_group = ( - torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size - ).to(torch.int32) - else: - assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols - assert expert_group.dtype == torch.int32 - - # Block sizes — single BLOCK_N pass for DeepSeek envelope. BLOCK_N must - # cover the shared-expert columns too so their bits fit in the bitmatrix. - BLOCK_M = 1 - BLOCK_N = max(32, triton.next_power_of_2(n_total)) - N_EXPTS_PAD = BLOCK_N - # Mirror topk(): pad to ≥ 2 to dodge tl.argmax/topk(k=1) compile quirks. - N_EXPTS_ACT_PAD = max(2, triton.next_power_of_2(k_out)) - BLOCK_S = 128 - BLOCK_SP = 128 - TILE_SIZE = 8 - - # Outputs (same shapes / dtypes as topk(...)), widened by the shared slots. - y_vals = torch.empty((n_rows, k_out), dtype=x.dtype, device=dev) - y_indx = torch.empty((n_rows, k_out), dtype=torch.int16, device=dev) - - # Bitmatrix in transposed-uint32 storage layout (identical to topk()). 
- n_cols_pad = triton.cdiv(n_total, BLOCK_N) * BLOCK_N - n_cols_words = n_cols_pad // 32 - bitmatrix_data = torch.empty( - (n_cols_words, triton.cdiv(n_rows, 32) * 32), - dtype=torch.uint32, - device=dev, - ) - bitmatrix_data = torch.transpose(bitmatrix_data, 0, 1)[:n_rows] - - # Scratchpads. The per-column sum buffer consumed by Bitmatrix.sum() / - # sort_tokens must cover the full padded column count (n_cols_pad), which - # widens with the shared experts; sizing by n_total alone can under-allocate - # (e.g. n_total=257 -> n_cols_pad=512 but cdiv(257,128)*128=384). - s_blocks = triton.cdiv(n_cols_pad, BLOCK_S) - s_cols = s_blocks * BLOCK_S - scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) - BLOCK_MM = HIST_BLOCK_M * TILE_SIZE - pids_x = triton.cdiv(n_rows, BLOCK_MM) - scratchpad_partials = torch.empty( - (n_cols_pad, pids_x * TILE_SIZE), dtype=torch.int32, device=dev - ) - scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) - sp_size = scratchpad_partials.numel() - sp_blocks = triton.cdiv(sp_size, BLOCK_SP) - - pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) - - _grouped_topk[(pids,)]( - x, - x.stride(0), - expert_group, - y_vals, - y_indx, - y_vals.stride(0), - bitmatrix_data, - bitmatrix_data.stride(0), - bitmatrix_data.stride(1), - n_rows, - n_cols, - scratchpad, - BLOCK_S, - s_blocks, - scratchpad_partials, - BLOCK_SP, - sp_blocks, - sp_size, - BLOCK_M=BLOCK_M, - N_EXPTS_PAD=N_EXPTS_PAD, - BLOCK_N=BLOCK_N, - N_EXPTS_ACT=k, - N_EXPTS_ACT_PAD=N_EXPTS_ACT_PAD, - NUM_EXPERT_GROUP=num_expert_group, - TOPK_GROUP=topk_group, - Bias=bias, - SCORE_MODE=score_mode, - HAS_BIAS=has_bias, - APPLY_RENORM=renorm, - ROUTED_SCALING=routed_scaling_factor, - N_SHARED=n_shared, - SHARED_SCORE=shared_experts_score, - num_warps=4, - ) - - bitmatrix = Bitmatrix( - bitmatrix_data, - shape=[n_rows, n_cols_words * 32], - scratchpad=scratchpad, - scratchpad_partials=scratchpad_partials, - ) - return y_vals, y_indx, bitmatrix diff --git 
a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index d8f2fef926..d62fa33b78 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -367,7 +367,7 @@ def routing_a8w4( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" # Lazy import: ATOM-side kernel; avoids hard aiter→atom import order. - from aiter.ops.triton.moe.moe_routing.grouped_topk import grouped_topk + from aiter.ops.triton.moe.moe_routing.topk import grouped_topk expt_scal, expt_indx, bitmatrix = grouped_topk( logits, diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index 4306184f7e..a42be56f6b 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py +++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -1,9 +1,179 @@ import triton import torch -from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import _topk, _hash_routing +from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import _topk, _hash_routing, _grouped_topk from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix +def grouped_topk( + x: torch.Tensor, + k: int, + num_expert_group: int, + topk_group: int, + *, + expert_group: torch.Tensor | None = None, + apply_softmax: bool = False, # accepted for parity with topk(); ignored + HIST_BLOCK_M: int = 32, + score_mode: str = "softmax", + bias: torch.Tensor | None = None, + renorm: bool = False, + routed_scaling_factor: float = 1.0, + num_fused_shared_experts: int = 0, + shared_experts_score: float = 1.0, +): + """Triton grouped top-k expert selection. See module docstring. + + Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of + ``aiter.ops.triton.moe.moe_routing.topk.topk``: + + - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. + - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. 
+ + When ``num_fused_shared_experts > 0`` the routed top-k selection occupies + the first ``k`` columns and the always-on shared expert(s) occupy the next + ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight + ``shared_experts_score`` (appended after the routed renorm, mirroring + ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). The bitmatrix + is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` + counts the shared bucket. + + - bitmatrix: real :class:`Bitmatrix`; same uint32 + ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the + ``_topk`` kernel emits, so ``sort_tokens`` and ``sort_tokens_fused`` + consume it unchanged. + """ + assert x.dim() == 2 + n_rows, n_cols = x.shape + assert n_cols <= 256, ( + f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" + ) + # Fused shared experts are appended (always-on) AFTER the routed selection; + # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). + n_shared = num_fused_shared_experts + assert n_shared >= 0 + n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) + k_out = k + n_shared # output width (routed top-k + shared) + assert num_expert_group > 1 + assert num_expert_group <= 16, ( + f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" + ) + assert 0 < topk_group <= num_expert_group + assert 0 < k <= 16 + assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( + f"unknown score_mode {score_mode!r}" + ) + has_bias = bias is not None + if has_bias: + assert bias.dim() == 1 and bias.shape[0] == n_cols + assert bias.dtype == torch.float32 + assert score_mode in ("sqrtsoftplus", "sigmoid"), ( + "bias only supported with sqrtsoftplus / sigmoid" + ) + + dev = x.device + + # Default expert→group mapping = contiguous DeepSeek layout. 
+ if expert_group is None: + assert n_cols % num_expert_group == 0, ( + f"n_expts_tot ({n_cols}) not divisible by num_expert_group " + f"({num_expert_group}); pass an explicit expert_group table." + ) + g_size = n_cols // num_expert_group + expert_group = ( + torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size + ).to(torch.int32) + else: + assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols + assert expert_group.dtype == torch.int32 + + # Block sizes — single BLOCK_N pass for DeepSeek envelope. BLOCK_N must + # cover the shared-expert columns too so their bits fit in the bitmatrix. + BLOCK_M = 1 + BLOCK_N = max(32, triton.next_power_of_2(n_total)) + N_EXPTS_PAD = BLOCK_N + # Mirror topk(): pad to ≥ 2 to dodge tl.argmax/topk(k=1) compile quirks. + N_EXPTS_ACT_PAD = max(2, triton.next_power_of_2(k_out)) + BLOCK_S = 128 + BLOCK_SP = 128 + TILE_SIZE = 8 + + # Outputs (same shapes / dtypes as topk(...)), widened by the shared slots. + y_vals = torch.empty((n_rows, k_out), dtype=x.dtype, device=dev) + y_indx = torch.empty((n_rows, k_out), dtype=torch.int16, device=dev) + + # Bitmatrix in transposed-uint32 storage layout (identical to topk()). + n_cols_pad = triton.cdiv(n_total, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + bitmatrix_data = torch.empty( + (n_cols_words, triton.cdiv(n_rows, 32) * 32), + dtype=torch.uint32, + device=dev, + ) + bitmatrix_data = torch.transpose(bitmatrix_data, 0, 1)[:n_rows] + + # Scratchpads. The per-column sum buffer consumed by Bitmatrix.sum() / + # sort_tokens must cover the full padded column count (n_cols_pad), which + # widens with the shared experts; sizing by n_total alone can under-allocate + # (e.g. n_total=257 -> n_cols_pad=512 but cdiv(257,128)*128=384). 
+ s_blocks = triton.cdiv(n_cols_pad, BLOCK_S) + s_cols = s_blocks * BLOCK_S + scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev) + BLOCK_MM = HIST_BLOCK_M * TILE_SIZE + pids_x = triton.cdiv(n_rows, BLOCK_MM) + scratchpad_partials = torch.empty( + (n_cols_pad, pids_x * TILE_SIZE), dtype=torch.int32, device=dev + ) + scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1) + sp_size = scratchpad_partials.numel() + sp_blocks = triton.cdiv(sp_size, BLOCK_SP) + + pids = max(triton.cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks) + + _grouped_topk[(pids,)]( + x, + x.stride(0), + expert_group, + y_vals, + y_indx, + y_vals.stride(0), + bitmatrix_data, + bitmatrix_data.stride(0), + bitmatrix_data.stride(1), + n_rows, + n_cols, + scratchpad, + BLOCK_S, + s_blocks, + scratchpad_partials, + BLOCK_SP, + sp_blocks, + sp_size, + BLOCK_M=BLOCK_M, + N_EXPTS_PAD=N_EXPTS_PAD, + BLOCK_N=BLOCK_N, + N_EXPTS_ACT=k, + N_EXPTS_ACT_PAD=N_EXPTS_ACT_PAD, + NUM_EXPERT_GROUP=num_expert_group, + TOPK_GROUP=topk_group, + Bias=bias, + SCORE_MODE=score_mode, + HAS_BIAS=has_bias, + APPLY_RENORM=renorm, + ROUTED_SCALING=routed_scaling_factor, + N_SHARED=n_shared, + SHARED_SCORE=shared_experts_score, + num_warps=4, + ) + + bitmatrix = Bitmatrix( + bitmatrix_data, + shape=[n_rows, n_cols_words * 32], + scratchpad=scratchpad, + scratchpad_partials=scratchpad_partials, + ) + return y_vals, y_indx, bitmatrix + + + def topk( x, k, From 46976ee6baf76b5efbcecd912f89d99b3917ce98 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Tue, 2 Jun 2026 18:37:18 +0000 Subject: [PATCH 05/18] fix testing --- .../triton_tests/moe/test_grouped_topk.py | 63 ++++++++++++------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/op_tests/triton_tests/moe/test_grouped_topk.py b/op_tests/triton_tests/moe/test_grouped_topk.py index ca3767b5bf..0465431adb 100644 --- a/op_tests/triton_tests/moe/test_grouped_topk.py +++ b/op_tests/triton_tests/moe/test_grouped_topk.py @@ -1,5 +1,5 @@ -"""Unit tests 
for ATOM's single-fused Triton grouped-top-k routing kernel -(``atom.model_ops.grouped_topk_triton.grouped_topk``). +"""Unit tests for aiter's single-fused Triton grouped-top-k routing kernel +(``aiter.ops.triton.moe.moe_routing.topk.grouped_topk``). Structured after ``test_moe_routing.py``: * Reference uses aiter's torch grouped-topk @@ -26,12 +26,7 @@ compute_expt_data_torch, ) -# grouped_topk lives in ATOM; skip the whole module if ATOM isn't importable -# in this environment (e.g. aiter-only CI). -atom_grouped_topk = pytest.importorskip( - "atom.model_ops.grouped_topk_triton" -).grouped_topk - +from aiter.ops.triton.moe.moe_routing.topk import grouped_topk # -------------------------------------------------------------------------- # comparison helpers (copied from test_moe_routing.py for self-containment) @@ -154,7 +149,15 @@ def _ref_contiguous( def _ref_arbitrary_grouped( - logits, expert_group, k, num_expert_group, topk_group, score_mode, bias, renorm, scale + logits, + expert_group, + k, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, ): """General reference honoring an arbitrary expert->group table (equal-size groups). 
Used for the non-contiguous mapping case where the aiter refs @@ -181,7 +184,9 @@ def _ref_arbitrary_grouped( group_scores[:, g] = sub.max(dim=-1).values group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices - group_sel = torch.zeros((nt, num_expert_group), device=logits.device, dtype=torch.bool) + group_sel = torch.zeros( + (nt, num_expert_group), device=logits.device, dtype=torch.bool + ) group_sel.scatter_(1, group_idx, True) # expert keep mask via group table lookup expert_keep = group_sel[:, expert_group.long()] # (nt, ne) @@ -210,9 +215,9 @@ def _assert_selection_matches(ref_ids, ref_w, tri_ids, tri_w): identical and gathered weights close.""" ref_ids_s, ref_w_s = _row_sort_by_id(ref_ids.cpu(), ref_w.cpu()) tri_ids_s, tri_w_s = _row_sort_by_id(tri_ids.cpu().long(), tri_w.cpu().float()) - assert torch.equal(ref_ids_s, tri_ids_s), ( - f"selected expert ids differ:\nref={ref_ids_s}\ntri={tri_ids_s}" - ) + assert torch.equal( + ref_ids_s, tri_ids_s + ), f"selected expert ids differ:\nref={ref_ids_s}\ntri={tri_ids_s}" assert_close(ref_w_s, tri_w_s, 2e-2, 4e-3, description="weights") @@ -260,8 +265,12 @@ def _check_routing_data_bucket( ref_hist, _, _, _, ref_expt_data = ref_pack assert_equal(ref_hist, tri_routing_data.expt_hist) assert_equal(ref_expt_data.hist, tri_routing_data.expt_data.hist) - assert_equal(ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw) - assert_equal(ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad) + assert_equal( + ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw + ) + assert_equal( + ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad + ) assert_equal(ref_expt_data.block_pid_map, tri_routing_data.expt_data.block_pid_map) n_tokens, n_expts_act = topk_ids.shape @@ -269,7 +278,9 @@ def _check_routing_data_bucket( n_expts_tot = ref_hist.numel() iota = torch.arange(n_gates, dtype=torch.int32, device=tri_gather.device) - assert 
torch.equal(tri_scatter[tri_gather.long()], iota), "scatter[gather[j]] != j" + assert torch.equal( + tri_scatter.long()[tri_gather.long()], iota + ), "scatter[gather[j]] != j" flat_ids = topk_ids.reshape(-1).cpu().tolist() flat_w = topk_weights.reshape(-1).float().cpu().tolist() @@ -368,7 +379,7 @@ def test_grouped_topk_kernel( renorm, scale, ) - y_vals, y_indx, bitmatrix = atom_grouped_topk( + y_vals, y_indx, bitmatrix = grouped_topk( logits, n_expts_act, num_expert_group=num_expert_group, @@ -425,7 +436,7 @@ def test_grouped_topk_arbitrary_group( renorm, scale, ) - y_vals, y_indx, bitmatrix = atom_grouped_topk( + y_vals, y_indx, bitmatrix = grouped_topk( logits, n_expts_act, num_expert_group=num_expert_group, @@ -463,7 +474,7 @@ def test_routing_a8w4_grouped( # The selection the kernel makes (deterministic for fixed inputs); used as # ground truth for the sort/scatter pipeline check. - y_vals, y_indx, _ = atom_grouped_topk( + y_vals, y_indx, _ = grouped_topk( logits, n_expts_act, num_expert_group=num_expert_group, @@ -522,10 +533,16 @@ def test_grouped_topk_shared_expert( score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 ref_w, ref_ids = _ref_contiguous( - logits.clone(), n_expts_act, num_expert_group, topk_group, - score_mode, bias, renorm, scale, + logits.clone(), + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, ) - y_vals, y_indx, bitmatrix = atom_grouped_topk( + y_vals, y_indx, bitmatrix = grouped_topk( logits, n_expts_act, num_expert_group=num_expert_group, @@ -600,4 +617,4 @@ def test_routing_a8w4_grouped_shared( n_gates = n_tokens * (n_expts_act + n_shared) iota = torch.arange(n_gates, dtype=torch.int32, device=gather.device) - assert torch.equal(scatter[gather.long()], iota), "scatter[gather[j]] != j" + assert torch.equal(scatter.long()[gather.long()], iota), "scatter[gather[j]] != j" From 587728db8c376a204352f798101f34788f53243f Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Wed, 3 Jun 2026 16:10:51 +0000 
Subject: [PATCH 06/18] refactor routing_a8w4 name to unify dsv4 and grouped topk + sigmoid --- aiter/ops/triton/moe/moe_routing/routing.py | 2 +- .../triton_tests/moe/test_grouped_topk.py | 620 ------------------ op_tests/triton_tests/moe/test_moe_routing.py | 488 +++++++++++++- 3 files changed, 485 insertions(+), 625 deletions(-) delete mode 100644 op_tests/triton_tests/moe/test_grouped_topk.py diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index d62fa33b78..8b7741d130 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -312,7 +312,7 @@ def routing(logits, n_expts_act, sm_first=False): ) -def routing_a8w4( +def routing_ds( logits: torch.Tensor, n_expts_act: int, block_m: int, diff --git a/op_tests/triton_tests/moe/test_grouped_topk.py b/op_tests/triton_tests/moe/test_grouped_topk.py deleted file mode 100644 index 0465431adb..0000000000 --- a/op_tests/triton_tests/moe/test_grouped_topk.py +++ /dev/null @@ -1,620 +0,0 @@ -"""Unit tests for aiter's single-fused Triton grouped-top-k routing kernel -(``aiter.ops.triton.moe.moe_routing.topk.grouped_topk``). - -Structured after ``test_moe_routing.py``: - * Reference uses aiter's torch grouped-topk - (``biased_grouped_topk_torch`` / ``grouped_topk_torch``) for the standard - contiguous DeepSeek group layout, plus a thin wrapper for the - ``sqrtsoftplus`` score mode and the ``routed_scaling_factor`` scale that - the aiter refs don't apply. - * ``(y_vals, y_indx)`` are compared per-row set-wise (sorted by expert id), - robust to the kernel returning experts in descending-score order. - * The emitted ``Bitmatrix`` is decoded and checked against the selected - expert set. - * End-to-end ``routing_a8w4(use_grouped_topk=True)`` is validated through the - sort_tokens / ExptData pipeline with a bucket-multiset check. 
-""" - -import pytest -import torch -import torch.nn.functional as F - -from aiter.ops.triton.utils._triton.arch_info import get_arch -from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch -from aiter.ops.triton.moe.moe_routing.routing import ( - routing_a8w4, - compute_expt_data_torch, -) - -from aiter.ops.triton.moe.moe_routing.topk import grouped_topk - -# -------------------------------------------------------------------------- -# comparison helpers (copied from test_moe_routing.py for self-containment) -# -------------------------------------------------------------------------- - - -def assert_equal(ref, tri): - if isinstance(ref, torch.Tensor): - assert ((ref.cpu().numpy() - tri.cpu().numpy()) ** 2).sum() == 0 - else: - assert ref == tri - - -def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): - if maxtol is None: - maxtol = 2e-2 - if rmstol is None: - rmstol = 4e-3 - ref = ref.to(torch.float32).detach() - tri = tri.to(torch.float32).detach() - assert ( - ref.shape == tri.shape - ), f"Tensors must have same size {ref.shape=} {tri.shape=}" - - inf_mask_ref = torch.isinf(ref) - inf_mask_tri = torch.isinf(tri) - assert torch.equal( - inf_mask_ref, inf_mask_tri - ), "Tensor must have same infinite elements" - refn = torch.where(inf_mask_ref, 0, ref) - trin = torch.where(inf_mask_tri, 0, tri) - - eps = 1.0e-30 - multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) - refn *= multiplier - trin *= multiplier - - ref_rms = torch.sqrt(torch.square(refn).mean()) + eps - rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) - max_err = torch.max(rel_err).item() - rms_err = torch.sqrt(torch.square(rel_err).mean()).item() - - if verbose: - print(f"{description} max rel err = {max_err} (thr {maxtol})") - print(f"{description} rms rel err = {rms_err} (thr {rmstol})") - - assert max_err <= maxtol - assert rms_err <= rmstol - - -def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"): 
- return torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device) - - -# -------------------------------------------------------------------------- -# torch references -# -------------------------------------------------------------------------- - - -def _ref_sqrtsoftplus_grouped( - logits, bias, k, num_expert_group, topk_group, renorm, scale -): - """sqrtsoftplus grouped-topk reference (no aiter equivalent exists). - - Mirrors the kernel: sqrt(softplus(logits)) transform, bias added for - SELECTION only, top-2-sum-per-group when biased else per-group max, mask - non-selected groups, top-k on the (biased) choice scores, gather UNBIASED - weights, renorm + scale. - """ - nt, ne = logits.shape - g_size = ne // num_expert_group - transform = torch.sqrt(F.softplus(logits.float())) - choice = transform + bias.float().unsqueeze(0) if bias is not None else transform - - if bias is not None: - group_scores = ( - choice.view(nt, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) - else: - group_scores = choice.view(nt, num_expert_group, -1).max(dim=-1).values - - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1.0) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(nt, num_expert_group, g_size) - .reshape(nt, ne) - .bool() - ) - tmp = choice.masked_fill(~score_mask, float("-inf")) - ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices - w = transform.gather(1, ids) - if renorm: - w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) - w = w * scale - return w.float(), ids.to(torch.int64) - - -def _ref_contiguous( - logits, k, num_expert_group, topk_group, score_mode, bias, renorm, scale -): - """Reference for contiguous DeepSeek group layout. 
Reuses aiter torch refs - where they apply, plus the sqrtsoftplus wrapper + scale.""" - if score_mode == "sqrtsoftplus": - return _ref_sqrtsoftplus_grouped( - logits, bias, k, num_expert_group, topk_group, renorm, scale - ) - if score_mode == "sigmoid" and bias is not None: - w, ids = biased_grouped_topk_torch( - logits, bias, k, renorm, num_expert_group, topk_group - ) - elif score_mode in ("sigmoid", "softmax"): - w, ids = grouped_topk_torch( - logits, k, renorm, num_expert_group, topk_group, scoring_func=score_mode - ) - else: - raise ValueError(score_mode) - return w.float() * scale, ids.to(torch.int64) - - -def _ref_arbitrary_grouped( - logits, - expert_group, - k, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, -): - """General reference honoring an arbitrary expert->group table (equal-size - groups). Used for the non-contiguous mapping case where the aiter refs - (which assume contiguous .view groups) don't apply.""" - nt, ne = logits.shape - f32 = logits.float() - if score_mode == "softmax": - scores = torch.softmax(f32, dim=-1) - elif score_mode == "sigmoid": - scores = f32.sigmoid() - elif score_mode == "sqrtsoftplus": - scores = torch.sqrt(F.softplus(f32)) - else: - scores = f32 - choice = scores + bias.float().unsqueeze(0) if bias is not None else scores - - group_scores = torch.empty((nt, num_expert_group), device=logits.device) - for g in range(num_expert_group): - cols = (expert_group == g).nonzero(as_tuple=False).flatten() - sub = choice[:, cols] - if bias is not None: - group_scores[:, g] = sub.topk(2, dim=-1)[0].sum(dim=-1) - else: - group_scores[:, g] = sub.max(dim=-1).values - - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices - group_sel = torch.zeros( - (nt, num_expert_group), device=logits.device, dtype=torch.bool - ) - group_sel.scatter_(1, group_idx, True) - # expert keep mask via group table lookup - expert_keep = group_sel[:, expert_group.long()] # (nt, ne) - - tmp = 
choice.masked_fill(~expert_keep, float("-inf")) - ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices - w = scores.gather(1, ids) - if renorm: - w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) - w = w * scale - return w.float(), ids.to(torch.int64) - - -# -------------------------------------------------------------------------- -# output comparison utilities -# -------------------------------------------------------------------------- - - -def _row_sort_by_id(ids, weights): - order = torch.argsort(ids, dim=1) - return torch.gather(ids, 1, order), torch.gather(weights, 1, order) - - -def _assert_selection_matches(ref_ids, ref_w, tri_ids, tri_w): - """Set-wise per-row comparison: sort both by expert id, then assert ids - identical and gathered weights close.""" - ref_ids_s, ref_w_s = _row_sort_by_id(ref_ids.cpu(), ref_w.cpu()) - tri_ids_s, tri_w_s = _row_sort_by_id(tri_ids.cpu().long(), tri_w.cpu().float()) - assert torch.equal( - ref_ids_s, tri_ids_s - ), f"selected expert ids differ:\nref={ref_ids_s}\ntri={tri_ids_s}" - assert_close(ref_w_s, tri_w_s, 2e-2, 4e-3, description="weights") - - -def _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot): - """Decode the packed uint32 Bitmatrix into a (n_tokens, n_expts_tot) bool - matrix of selected experts.""" - data = bitmatrix.data[:n_tokens].to(torch.int64) # (n_tokens, n_cols_words) - n_cols_words = data.shape[1] - bits = torch.arange(32, device=data.device, dtype=torch.int64) - unpacked = ((data.unsqueeze(-1) >> bits) & 1).bool() # (nt, words, 32) - unpacked = unpacked.reshape(n_tokens, n_cols_words * 32) - return unpacked[:, :n_expts_tot] - - -def _assert_bitmatrix_matches(bitmatrix, tri_ids, n_tokens, n_expts_tot): - decoded = _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot).cpu() - expected = torch.zeros((n_tokens, n_expts_tot), dtype=torch.bool) - expected.scatter_(1, tri_ids.cpu().long(), True) - assert torch.equal(decoded, expected), "bitmatrix does not match selected ids" - - -# 
-------------------------------------------------------------------------- -# end-to-end routing helpers (mirror of test_moe_routing.py, compacted) -# -------------------------------------------------------------------------- - - -def _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m): - n_tokens, n_expts_act = expt_scal.shape - n_gates = n_tokens * n_expts_act - scal_flat = expt_scal.reshape(-1) - indx_flat = expt_indx.reshape(-1).to(torch.int32) - topk_indx = torch.argsort(indx_flat, stable=True).to(torch.int32) - gate_indx = torch.argsort(topk_indx, stable=True).to(torch.int32) - gate_scal = scal_flat[topk_indx.long()] - hist = torch.histc( - indx_flat.float(), bins=n_expts_tot, min=0, max=n_expts_tot - 1 - ).int() - expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates, block_m) - return hist, topk_indx, gate_indx, gate_scal, expt_data - - -def _check_routing_data_bucket( - ref_pack, tri_routing_data, tri_gather, tri_scatter, topk_weights, topk_ids -): - ref_hist, _, _, _, ref_expt_data = ref_pack - assert_equal(ref_hist, tri_routing_data.expt_hist) - assert_equal(ref_expt_data.hist, tri_routing_data.expt_data.hist) - assert_equal( - ref_expt_data.token_offs_raw, tri_routing_data.expt_data.token_offs_raw - ) - assert_equal( - ref_expt_data.token_offs_pad, tri_routing_data.expt_data.token_offs_pad - ) - assert_equal(ref_expt_data.block_pid_map, tri_routing_data.expt_data.block_pid_map) - - n_tokens, n_expts_act = topk_ids.shape - n_gates = n_tokens * n_expts_act - n_expts_tot = ref_hist.numel() - - iota = torch.arange(n_gates, dtype=torch.int32, device=tri_gather.device) - assert torch.equal( - tri_scatter.long()[tri_gather.long()], iota - ), "scatter[gather[j]] != j" - - flat_ids = topk_ids.reshape(-1).cpu().tolist() - flat_w = topk_weights.reshape(-1).float().cpu().tolist() - src = tri_gather.cpu().tolist() - scal = tri_routing_data.gate_scal.float().cpu().tolist() - cum = torch.cumsum(ref_hist, dim=0).cpu().tolist() - - ground = {e: [] 
for e in range(n_expts_tot)} - for i, e in enumerate(flat_ids): - ground[e].append((i // n_expts_act, flat_w[i])) - for e in ground: - ground[e].sort() - - got = {e: [] for e in range(n_expts_tot)} - e = 0 - for j in range(n_gates): - while e < n_expts_tot and j >= cum[e]: - e += 1 - assert flat_ids[src[j]] == e, f"bucket-invariant violated at pos {j}" - got[e].append((src[j] // n_expts_act, scal[j])) - for e in got: - got[e].sort() - - for e in range(n_expts_tot): - rb, tb = ground[e], got[e] - assert len(rb) == len(tb), f"expert {e}: ref={len(rb)} test={len(tb)}" - for (tt_r, w_r), (tt_t, w_t) in zip(rb, tb): - assert tt_r == tt_t, f"expert {e}: token ref={tt_r} test={tt_t}" - assert abs(w_r - w_t) <= 1e-6, f"expert {e} token {tt_r}: w {w_r} vs {w_t}" - - -# -------------------------------------------------------------------------- -# parametrization -# -------------------------------------------------------------------------- - -# (n_expts_tot, num_expert_group, topk_group, n_expts_act) — DeepSeek-like. -GROUP_SHAPES = [ - (256, 8, 4, 8), - (128, 8, 4, 6), -] -# n_tokens spanning the fused (<=16) and regular sort_tokens paths. -N_TOKENS = [8, 16, 64, 1024] -# (score_mode, has_bias, renorm, routed_scaling_factor) — production-core set. -SCORE_COMBOS = [ - ("sqrtsoftplus", True, True, 2.5), - ("sigmoid", True, True, 1.0), - ("softmax", False, False, 1.0), -] - - -def _maybe_skip(): - if not torch.cuda.is_available(): - pytest.skip("grouped_topk requires a GPU") - if get_arch() not in ["gfx950", "gfx1250"]: - pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") - - -# -------------------------------------------------------------------------- -# 1. 
direct kernel test: (y_vals, y_indx, bitmatrix) -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", N_TOKENS) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("score_mode, has_bias, renorm, scale", SCORE_COMBOS) -def test_grouped_topk_kernel( - n_tokens, - n_expts_tot, - num_expert_group, - topk_group, - n_expts_act, - score_mode, - has_bias, - renorm, - scale, -): - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = ( - torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - if has_bias - else None - ) - - ref_w, ref_ids = _ref_contiguous( - logits.clone(), - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) - y_vals, y_indx, bitmatrix = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - ) - - assert y_vals.shape == (n_tokens, n_expts_act) - assert y_indx.shape == (n_tokens, n_expts_act) - assert y_indx.dtype == torch.int16 - assert y_vals.dtype == logits.dtype - - _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) - - -# -------------------------------------------------------------------------- -# 2. 
arbitrary (non-contiguous) expert->group mapping -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -def test_grouped_topk_arbitrary_group( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act -): - _maybe_skip() - device = "cuda" - torch.manual_seed(7) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - - # Equal-size groups but a shuffled (non-contiguous) expert->group table. - g_size = n_expts_tot // num_expert_group - perm = torch.randperm(n_expts_tot, device=device) - expert_group = torch.empty(n_expts_tot, dtype=torch.int32, device=device) - for g in range(num_expert_group): - expert_group[perm[g * g_size : (g + 1) * g_size]] = g - - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - ref_w, ref_ids = _ref_arbitrary_grouped( - logits.clone(), - expert_group, - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) - y_vals, y_indx, bitmatrix = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - expert_group=expert_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - ) - - _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) - - -# -------------------------------------------------------------------------- -# 3. 
end-to-end routing_a8w4(use_grouped_topk=True) -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("block_m", [16, 32]) -def test_routing_a8w4_grouped( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m -): - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - - # The selection the kernel makes (deterministic for fixed inputs); used as - # ground truth for the sort/scatter pipeline check. - y_vals, y_indx, _ = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - ) - - tri_routing_data, tri_gather, tri_scatter = routing_a8w4( - logits, - n_expts_act, - block_m, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - use_grouped_topk=True, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) - - ref_pack = _sort_and_build_torch( - y_vals.float(), y_indx.to(torch.int32), n_expts_tot, block_m - ) - _check_routing_data_bucket( - ref_pack, tri_routing_data, tri_gather, tri_scatter, y_vals.float(), y_indx - ) - assert tri_routing_data.n_expts_tot == n_expts_tot - assert tri_routing_data.n_expts_act == n_expts_act - assert tri_routing_data.block_m == block_m - - -# -------------------------------------------------------------------------- -# 4. 
fused shared experts (DeepSeek-R1/V3 always-on shared expert) -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("n_shared", [1, 2]) -def test_grouped_topk_shared_expert( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared -): - """The kernel appends `n_shared` always-on shared experts (id n_expts_tot+i, - weight 1.0) AFTER the routed renorm. The routed portion must still match the - reference, and the shared columns + bitmatrix must reflect the append.""" - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - - ref_w, ref_ids = _ref_contiguous( - logits.clone(), - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) - y_vals, y_indx, bitmatrix = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - num_fused_shared_experts=n_shared, - shared_experts_score=1.0, - ) - - assert y_vals.shape == (n_tokens, n_expts_act + n_shared) - assert y_indx.shape == (n_tokens, n_expts_act + n_shared) - - # Routed slots (first n_expts_act) must match the reference selection. - _assert_selection_matches( - ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] - ) - - # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. 
- for i in range(n_shared): - ids_i = y_indx[:, n_expts_act + i].cpu().long() - w_i = y_vals[:, n_expts_act + i].float().cpu() - assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" - assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" - - # Bitmatrix must contain routed + shared selections over the widened width. - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) - - -@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("block_m", [16, 32]) -@pytest.mark.parametrize("n_shared", [1, 2]) -def test_routing_a8w4_grouped_shared( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m, n_shared -): - """End-to-end routing_a8w4 with fused shared experts: histogram must include - a full shared bucket (n_tokens) per shared expert and the gather/scatter must - form a valid inverse permutation over the widened gate count.""" - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - - rd, gather, scatter = routing_a8w4( - logits, - n_expts_act, - block_m, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - use_grouped_topk=True, - num_expert_group=num_expert_group, - topk_group=topk_group, - num_fused_shared_experts=n_shared, - ) - - assert rd.n_expts_tot == n_expts_tot + n_shared - assert rd.n_expts_act == n_expts_act + n_shared - - # Every token is routed to each shared expert exactly once. 
- for i in range(n_shared): - assert rd.expt_hist[n_expts_tot + i].item() == n_tokens - assert rd.expt_hist.sum().item() == n_tokens * (n_expts_act + n_shared) - - n_gates = n_tokens * (n_expts_act + n_shared) - iota = torch.arange(n_gates, dtype=torch.int32, device=gather.device) - assert torch.equal(scatter.long()[gather.long()], iota), "scatter[gather[j]] != j" diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 65132a72b3..6abe69a46c 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -3,13 +3,15 @@ import torch.nn.functional as F from aiter.ops.triton.moe.moe_routing.routing import ( routing, - routing_a8w4, + routing_ds, routing_a8w4_from_hash, routing_a8w4_from_topk, routing_torch, compute_expt_data_torch, ) from aiter.ops.triton.utils._triton.arch_info import get_arch +from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch +from aiter.ops.triton.moe.moe_routing.topk import grouped_topk def assert_equal(ref, tri): @@ -313,8 +315,12 @@ def _check_routing_data_bucket( n_expts_tot = ref_hist.numel() # Inverse permutation invariant: gate_indx[topk_indx[j]] == j. - iota = torch.arange(n_gates, dtype=torch.int32, device=tri_gather.device) - assert torch.equal(tri_scatter[tri_gather.long()], iota), "scatter[gather[j]] != j" + # Cast scatter to int64 first: the grouped routing_ds path returns uint16 + # indices, which CUDA cannot advanced-index. + iota = torch.arange(n_gates, dtype=torch.int64, device=tri_gather.device) + assert torch.equal( + tri_scatter.long()[tri_gather.long()], iota + ), "scatter[gather[j]] != j" # Per-expert (token, weight) multisets. 
flat_ids = topk_ids.reshape(-1).cpu().tolist() @@ -412,7 +418,7 @@ def test_routing_a8w4( renorm=renorm, routed_scaling_factor=routed_scaling_factor, ) - tri_routing_data, tri_gather, tri_scatter = routing_a8w4( + tri_routing_data, tri_gather, tri_scatter = routing_ds( logits, n_expts_act, block_m, @@ -574,6 +580,480 @@ def test_routing_a8w4_from_topk( assert tri_routing_data.block_m == block_m +# ========================================================================== +# grouped-top-k routing (aiter.ops.triton.moe.moe_routing.topk.grouped_topk) +# Moved from test_grouped_topk.py. Reuses the shared helpers above +# (assert_equal, assert_close, init_data, _sort_and_build_torch, +# _check_routing_data_bucket). +# ========================================================================== + + +# -------------------------------------------------------------------------- +# torch references +# -------------------------------------------------------------------------- + + +def _ref_sqrtsoftplus_grouped( + logits, bias, k, num_expert_group, topk_group, renorm, scale +): + """sqrtsoftplus grouped-topk reference (no aiter equivalent exists). + + Mirrors the kernel: sqrt(softplus(logits)) transform, bias added for + SELECTION only, top-2-sum-per-group when biased else per-group max, mask + non-selected groups, top-k on the (biased) choice scores, gather UNBIASED + weights, renorm + scale. 
+ """ + nt, ne = logits.shape + g_size = ne // num_expert_group + transform = torch.sqrt(F.softplus(logits.float())) + choice = transform + bias.float().unsqueeze(0) if bias is not None else transform + + if bias is not None: + group_scores = ( + choice.view(nt, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = choice.view(nt, num_expert_group, -1).max(dim=-1).values + + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1.0) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(nt, num_expert_group, g_size) + .reshape(nt, ne) + .bool() + ) + tmp = choice.masked_fill(~score_mask, float("-inf")) + ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices + w = transform.gather(1, ids) + if renorm: + w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) + w = w * scale + return w.float(), ids.to(torch.int64) + + +def _ref_contiguous( + logits, k, num_expert_group, topk_group, score_mode, bias, renorm, scale +): + """Reference for contiguous DeepSeek group layout. Reuses aiter torch refs + where they apply, plus the sqrtsoftplus wrapper + scale.""" + if score_mode == "sqrtsoftplus": + return _ref_sqrtsoftplus_grouped( + logits, bias, k, num_expert_group, topk_group, renorm, scale + ) + if score_mode == "sigmoid" and bias is not None: + w, ids = biased_grouped_topk_torch( + logits, bias, k, renorm, num_expert_group, topk_group + ) + elif score_mode in ("sigmoid", "softmax"): + w, ids = grouped_topk_torch( + logits, k, renorm, num_expert_group, topk_group, scoring_func=score_mode + ) + else: + raise ValueError(score_mode) + return w.float() * scale, ids.to(torch.int64) + + +def _ref_arbitrary_grouped( + logits, + expert_group, + k, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, +): + """General reference honoring an arbitrary expert->group table (equal-size + groups). 
Used for the non-contiguous mapping case where the aiter refs + (which assume contiguous .view groups) don't apply.""" + nt, ne = logits.shape + f32 = logits.float() + if score_mode == "softmax": + scores = torch.softmax(f32, dim=-1) + elif score_mode == "sigmoid": + scores = f32.sigmoid() + elif score_mode == "sqrtsoftplus": + scores = torch.sqrt(F.softplus(f32)) + else: + scores = f32 + choice = scores + bias.float().unsqueeze(0) if bias is not None else scores + + group_scores = torch.empty((nt, num_expert_group), device=logits.device) + for g in range(num_expert_group): + cols = (expert_group == g).nonzero(as_tuple=False).flatten() + sub = choice[:, cols] + if bias is not None: + group_scores[:, g] = sub.topk(2, dim=-1)[0].sum(dim=-1) + else: + group_scores[:, g] = sub.max(dim=-1).values + + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices + group_sel = torch.zeros( + (nt, num_expert_group), device=logits.device, dtype=torch.bool + ) + group_sel.scatter_(1, group_idx, True) + # expert keep mask via group table lookup + expert_keep = group_sel[:, expert_group.long()] # (nt, ne) + + tmp = choice.masked_fill(~expert_keep, float("-inf")) + ids = torch.topk(tmp, k=k, dim=-1, sorted=False).indices + w = scores.gather(1, ids) + if renorm: + w = w / (w.sum(dim=-1, keepdim=True) + 1e-20) + w = w * scale + return w.float(), ids.to(torch.int64) + + +# -------------------------------------------------------------------------- +# output comparison utilities +# -------------------------------------------------------------------------- + + +def _row_sort_by_id(ids, weights): + order = torch.argsort(ids, dim=1) + return torch.gather(ids, 1, order), torch.gather(weights, 1, order) + + +def _assert_selection_matches(ref_ids, ref_w, tri_ids, tri_w): + """Set-wise per-row comparison: sort both by expert id, then assert ids + identical and gathered weights close.""" + ref_ids_s, ref_w_s = _row_sort_by_id(ref_ids.cpu(), ref_w.cpu()) + tri_ids_s, 
tri_w_s = _row_sort_by_id(tri_ids.cpu().long(), tri_w.cpu().float()) + assert torch.equal( + ref_ids_s, tri_ids_s + ), f"selected expert ids differ:\nref={ref_ids_s}\ntri={tri_ids_s}" + assert_close(ref_w_s, tri_w_s, 2e-2, 4e-3, description="weights") + + +def _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot): + """Decode the packed uint32 Bitmatrix into a (n_tokens, n_expts_tot) bool + matrix of selected experts.""" + data = bitmatrix.data[:n_tokens].to(torch.int64) # (n_tokens, n_cols_words) + n_cols_words = data.shape[1] + bits = torch.arange(32, device=data.device, dtype=torch.int64) + unpacked = ((data.unsqueeze(-1) >> bits) & 1).bool() # (nt, words, 32) + unpacked = unpacked.reshape(n_tokens, n_cols_words * 32) + return unpacked[:, :n_expts_tot] + + +def _assert_bitmatrix_matches(bitmatrix, tri_ids, n_tokens, n_expts_tot): + decoded = _decode_bitmatrix(bitmatrix, n_tokens, n_expts_tot).cpu() + expected = torch.zeros((n_tokens, n_expts_tot), dtype=torch.bool) + expected.scatter_(1, tri_ids.cpu().long(), True) + assert torch.equal(decoded, expected), "bitmatrix does not match selected ids" + + +# -------------------------------------------------------------------------- +# parametrization +# -------------------------------------------------------------------------- + +# (n_expts_tot, num_expert_group, topk_group, n_expts_act) — DeepSeek-like. +GROUP_SHAPES = [ + (256, 8, 4, 8), + (128, 8, 4, 6), +] +# n_tokens spanning the fused (<=16) and regular sort_tokens paths. +GROUPED_N_TOKENS = [8, 16, 64, 1024] +# (score_mode, has_bias, renorm, routed_scaling_factor) — production-core set. 
+SCORE_COMBOS = [ + ("sqrtsoftplus", True, True, 2.5), + ("sigmoid", True, True, 1.0), + ("softmax", False, False, 1.0), +] + + +def _maybe_skip(): + if not torch.cuda.is_available(): + pytest.skip("grouped_topk requires a GPU") + if get_arch() not in ["gfx950", "gfx1250"]: + pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") + + +# -------------------------------------------------------------------------- +# 1. direct kernel test: (y_vals, y_indx, bitmatrix) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", GROUPED_N_TOKENS) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("score_mode, has_bias, renorm, scale", SCORE_COMBOS) +def test_grouped_topk_kernel( + n_tokens, + n_expts_tot, + num_expert_group, + topk_group, + n_expts_act, + score_mode, + has_bias, + renorm, + scale, +): + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = ( + torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + if has_bias + else None + ) + + ref_w, ref_ids = _ref_contiguous( + logits.clone(), + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + assert y_vals.shape == (n_tokens, n_expts_act) + assert y_indx.shape == (n_tokens, n_expts_act) + assert y_indx.dtype == torch.int16 + assert y_vals.dtype == logits.dtype + + _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) + + +# -------------------------------------------------------------------------- +# 2. 
arbitrary (non-contiguous) expert->group mapping +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +def test_grouped_topk_arbitrary_group( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act +): + _maybe_skip() + device = "cuda" + torch.manual_seed(7) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + + # Equal-size groups but a shuffled (non-contiguous) expert->group table. + g_size = n_expts_tot // num_expert_group + perm = torch.randperm(n_expts_tot, device=device) + expert_group = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + for g in range(num_expert_group): + expert_group[perm[g * g_size : (g + 1) * g_size]] = g + + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + ref_w, ref_ids = _ref_arbitrary_grouped( + logits.clone(), + expert_group, + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + expert_group=expert_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) + + +# -------------------------------------------------------------------------- +# 3. 
end-to-end routing_ds(use_grouped_topk=True) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("block_m", [16, 32]) +def test_routing_ds_grouped( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m +): + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + # The selection the kernel makes (deterministic for fixed inputs); used as + # ground truth for the sort/scatter pipeline check. + y_vals, y_indx, _ = grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + ) + + tri_routing_data, tri_gather, tri_scatter = routing_ds( + logits, + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + use_grouped_topk=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + + ref_pack = _sort_and_build_torch( + y_vals.float(), y_indx.to(torch.int32), n_expts_tot, block_m + ) + _check_routing_data_bucket( + ref_pack, tri_routing_data, tri_gather, tri_scatter, y_vals.float(), y_indx + ) + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.block_m == block_m + + +# -------------------------------------------------------------------------- +# 4. 
fused shared experts (DeepSeek-R1/V3 always-on shared expert) +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("n_shared", [1, 2]) +def test_grouped_topk_shared_expert( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared +): + """The kernel appends `n_shared` always-on shared experts (id n_expts_tot+i, + weight 1.0) AFTER the routed renorm. The routed portion must still match the + reference, and the shared columns + bitmatrix must reflect the append.""" + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + ref_w, ref_ids = _ref_contiguous( + logits.clone(), + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = grouped_topk( + logits, + n_expts_act, + num_expert_group=num_expert_group, + topk_group=topk_group, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + num_fused_shared_experts=n_shared, + shared_experts_score=1.0, + ) + + assert y_vals.shape == (n_tokens, n_expts_act + n_shared) + assert y_indx.shape == (n_tokens, n_expts_act + n_shared) + + # Routed slots (first n_expts_act) must match the reference selection. + _assert_selection_matches( + ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] + ) + + # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. 
+ for i in range(n_shared): + ids_i = y_indx[:, n_expts_act + i].cpu().long() + w_i = y_vals[:, n_expts_act + i].float().cpu() + assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" + assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" + + # Bitmatrix must contain routed + shared selections over the widened width. + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) + + +@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) +@pytest.mark.parametrize( + "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES +) +@pytest.mark.parametrize("block_m", [16, 32]) +@pytest.mark.parametrize("n_shared", [1, 2]) +def test_routing_ds_grouped_shared( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m, n_shared +): + """End-to-end routing_ds with fused shared experts: histogram must include + a full shared bucket (n_tokens) per shared expert and the gather/scatter must + form a valid inverse permutation over the widened gate count.""" + _maybe_skip() + device = "cuda" + torch.manual_seed(2) + logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 + score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + + rd, gather, scatter = routing_ds( + logits, + n_expts_act, + block_m, + score_mode=score_mode, + bias=bias, + renorm=renorm, + routed_scaling_factor=scale, + use_grouped_topk=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + num_fused_shared_experts=n_shared, + ) + + assert rd.n_expts_tot == n_expts_tot + n_shared + assert rd.n_expts_act == n_expts_act + n_shared + + # Every token is routed to each shared expert exactly once. 
+ for i in range(n_shared): + assert rd.expt_hist[n_expts_tot + i].item() == n_tokens + assert rd.expt_hist.sum().item() == n_tokens * (n_expts_act + n_shared) + + n_gates = n_tokens * (n_expts_act + n_shared) + iota = torch.arange(n_gates, dtype=torch.int32, device=gather.device) + assert torch.equal(scatter.long()[gather.long()], iota), "scatter[gather[j]] != j" + + def bench_routing(): import triton.profiler as proton From cadb70851e6fd0302ee2f37962bc290bf7f7feff Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Wed, 3 Jun 2026 16:20:41 +0000 Subject: [PATCH 07/18] ruff --- aiter/ops/triton/moe/moe_op_gemm_a8w4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py index 96d742bfcb..bda2efc0c7 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py @@ -5,7 +5,6 @@ import itertools import os import json -import json import torch import triton from aiter.ops.triton.moe.moe_routing.routing import RoutingData From 8066312498bfa810439f4548971dc3f34c1a09cf Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Wed, 3 Jun 2026 16:29:53 +0000 Subject: [PATCH 08/18] black formatting --- .../_triton_kernels/moe/moe_routing/topk.py | 55 +++++++++---------- aiter/ops/triton/moe/moe_routing/routing.py | 6 +- aiter/ops/triton/moe/moe_routing/topk.py | 41 ++++++++------ 3 files changed, 52 insertions(+), 50 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py index 86abb99b52..014fdab791 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py @@ -428,24 +428,23 @@ def _hash_routing( tl.store(BitsPtrs, r, mask=mask_m[:, None]) - @triton.jit def _grouped_topk( - X, # router logits [n_rows, n_expts_tot] (bf16/fp32) + X, # router logits [n_rows, n_expts_tot] (bf16/fp32) stride_xm, - ExpertGroup, # 
int32 [n_expts_tot] expert→group_id - Yv, # [n_rows, N_EXPTS_ACT_PAD] selected weights - Yi, # [n_rows, N_EXPTS_ACT_PAD] selected expert ids (int16) + ExpertGroup, # int32 [n_expts_tot] expert→group_id + Yv, # [n_rows, N_EXPTS_ACT_PAD] selected weights + Yi, # [n_rows, N_EXPTS_ACT_PAD] selected expert ids (int16) stride_ym, - Bits, # bitmatrix data + Bits, # bitmatrix data stride_rm, stride_rn, n_rows, n_expts_tot, - S, # bitmatrix scratchpad — must memset to 0 + S, # bitmatrix scratchpad — must memset to 0 BLOCK_S: tl.constexpr, s_blocks, - SP, # bitmatrix partials — must memset to 0 + SP, # bitmatrix partials — must memset to 0 BLOCK_SP: tl.constexpr, sp_blocks, sp_size, @@ -528,7 +527,7 @@ def _grouped_topk( # -- 4. Per-group reduction over arbitrary expert→group mapping. gid = tl.load(ExpertGroup + offs_n, mask=mask_n, other=0).to(tl.int32) g_arange = tl.arange(0, NUM_EXPERT_GROUP) - gid_eq = gid[:, None] == g_arange[None, :] # [BLOCK_N, NUM_EXPERT_GROUP] + gid_eq = gid[:, None] == g_arange[None, :] # [BLOCK_N, NUM_EXPERT_GROUP] # 3-D one-hot expand: [BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP], with -inf # outside each group's column. @@ -536,7 +535,7 @@ def _grouped_topk( BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP ) expanded = tl.where(gid_eq[None, :, :], sfc_3d, float("-inf")) - group_max1 = tl.max(expanded, axis=1) # [BLOCK_M, NUM_EXPERT_GROUP] + group_max1 = tl.max(expanded, axis=1) # [BLOCK_M, NUM_EXPERT_GROUP] if HAS_BIAS: # Top-2-sum-per-group. 
To find the second-largest score per group @@ -545,13 +544,11 @@ def _grouped_topk( gm1_per_e = tl.sum( gid_eq[None, :, :].to(tl.float32) * group_max1[:, None, :], axis=2, - ) # [BLOCK_M, BLOCK_N] + ) # [BLOCK_M, BLOCK_N] suppressed = tl.where( scores_for_choice == gm1_per_e, float("-inf"), scores_for_choice ) - sup_3d = suppressed[:, :, None].broadcast_to( - BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP - ) + sup_3d = suppressed[:, :, None].broadcast_to(BLOCK_M, BLOCK_N, NUM_EXPERT_GROUP) expanded2 = tl.where(gid_eq[None, :, :], sup_3d, float("-inf")) group_max2 = tl.max(expanded2, axis=1) group_scores = group_max1 + group_max2 @@ -563,17 +560,20 @@ def _grouped_topk( group_mask_i = tl.zeros([BLOCK_M, NUM_EXPERT_GROUP], dtype=tl.int32) gs = group_scores for _gj in tl.static_range(TOPK_GROUP): - am_g = tl.argmax(gs, axis=1).to(tl.int32) # [BLOCK_M] - sel_g = (g_arange[None, :] == am_g[:, None]) # [BLOCK_M, NUM_EXPERT_GROUP] + am_g = tl.argmax(gs, axis=1).to(tl.int32) # [BLOCK_M] + sel_g = g_arange[None, :] == am_g[:, None] # [BLOCK_M, NUM_EXPERT_GROUP] group_mask_i = group_mask_i | sel_g.to(tl.int32) gs = tl.where(sel_g, float("-inf"), gs) # -- 6. Per-(token, expert) keep-mask via group-id lookup, then suppress # experts in non-selected groups on the bias-augmented scores. - expert_keep = tl.sum( - gid_eq[None, :, :].to(tl.int32) * group_mask_i[:, None, :], - axis=2, - ) > 0 # [BLOCK_M, BLOCK_N] + expert_keep = ( + tl.sum( + gid_eq[None, :, :].to(tl.int32) * group_mask_i[:, None, :], + axis=2, + ) + > 0 + ) # [BLOCK_M, BLOCK_N] sfc_masked = tl.where(expert_keep, scores_for_choice, float("-inf")) # -- 7. Per-expert top-``N_EXPTS_ACT`` via repeated argmax. 
Padded slots @@ -583,17 +583,15 @@ def _grouped_topk( y_indices = tl.zeros([BLOCK_M, N_EXPTS_ACT_PAD], dtype=tl.int32) sfc_iter = sfc_masked for kj in tl.static_range(N_EXPTS_ACT): - am_k = tl.argmax(sfc_iter, axis=1).to(tl.int32) # [BLOCK_M] + am_k = tl.argmax(sfc_iter, axis=1).to(tl.int32) # [BLOCK_M] slot_eq = (tl.arange(0, N_EXPTS_ACT_PAD) == kj)[None, :] y_indices = tl.where(slot_eq, am_k[:, None], y_indices) - sfc_iter = tl.where( - n_arange[None, :] == am_k[:, None], float("-inf"), sfc_iter - ) + sfc_iter = tl.where(n_arange[None, :] == am_k[:, None], float("-inf"), sfc_iter) # -- 8. Gather UNBIASED weights at selected indices. pos_eq = ( n_arange[None, None, :] == y_indices[:, :, None] - ) # [BLOCK_M, K_PAD, BLOCK_N] + ) # [BLOCK_M, K_PAD, BLOCK_N] scores_3d = scores[:, None, :].broadcast_to(BLOCK_M, N_EXPTS_ACT_PAD, BLOCK_N) y_weights = tl.sum(tl.where(pos_eq, scores_3d, 0.0), axis=2) # [BLOCK_M, K_PAD] @@ -641,7 +639,7 @@ def _grouped_topk( safe_idx = tl.where(real_mask, y_indices, 0).to(tl.uint32) y_div = safe_idx // 32 y_rem = safe_idx % 32 - bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N # = 1 (single-block) + bm_iters: tl.constexpr = N_EXPTS_PAD // BLOCK_N # = 1 (single-block) for i in range(bm_iters): offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) y2 = tl.where( @@ -650,8 +648,5 @@ def _grouped_topk( 0, ) r = tl.reduce_or(y2, axis=1) - BitsPtrs = ( - Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn - ) + BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn tl.store(BitsPtrs, r, mask=mask_m[:, None]) - diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index 8b7741d130..de1c577b83 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -387,9 +387,9 @@ def routing_ds( # Routed top-k + appended shared experts per token. 
n_expts_act = n_expts_act + n_shared else: - assert n_shared == 0, ( - "fused shared experts are only supported on the grouped-topk path" - ) + assert ( + n_shared == 0 + ), "fused shared experts are only supported on the grouped-topk path" from .topk import topk expt_scal, expt_indx, bitmatrix = topk( diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index a42be56f6b..a5627dd577 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py +++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -1,6 +1,10 @@ import triton import torch -from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import _topk, _hash_routing, _grouped_topk +from aiter.ops.triton._triton_kernels.moe.moe_routing.topk import ( + _topk, + _hash_routing, + _grouped_topk, +) from aiter.ops.triton.moe.moe_routing.bitmatrix import Bitmatrix @@ -11,7 +15,7 @@ def grouped_topk( topk_group: int, *, expert_group: torch.Tensor | None = None, - apply_softmax: bool = False, # accepted for parity with topk(); ignored + apply_softmax: bool = False, # accepted for parity with topk(); ignored HIST_BLOCK_M: int = 32, score_mode: str = "softmax", bias: torch.Tensor | None = None, @@ -43,31 +47,35 @@ def grouped_topk( """ assert x.dim() == 2 n_rows, n_cols = x.shape - assert n_cols <= 256, ( - f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" - ) + assert ( + n_cols <= 256 + ), f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" # Fused shared experts are appended (always-on) AFTER the routed selection; # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). n_shared = num_fused_shared_experts assert n_shared >= 0 - n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) - k_out = k + n_shared # output width (routed top-k + shared) + n_total = n_cols + n_shared # experts incl. 
shared (bitmatrix width) + k_out = k + n_shared # output width (routed top-k + shared) assert num_expert_group > 1 - assert num_expert_group <= 16, ( - f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" - ) + assert ( + num_expert_group <= 16 + ), f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" assert 0 < topk_group <= num_expert_group assert 0 < k <= 16 - assert score_mode in ("softmax", "sigmoid", "sqrtsoftplus", "none"), ( - f"unknown score_mode {score_mode!r}" - ) + assert score_mode in ( + "softmax", + "sigmoid", + "sqrtsoftplus", + "none", + ), f"unknown score_mode {score_mode!r}" has_bias = bias is not None if has_bias: assert bias.dim() == 1 and bias.shape[0] == n_cols assert bias.dtype == torch.float32 - assert score_mode in ("sqrtsoftplus", "sigmoid"), ( - "bias only supported with sqrtsoftplus / sigmoid" - ) + assert score_mode in ( + "sqrtsoftplus", + "sigmoid", + ), "bias only supported with sqrtsoftplus / sigmoid" dev = x.device @@ -173,7 +181,6 @@ def grouped_topk( return y_vals, y_indx, bitmatrix - def topk( x, k, From 8c4cd018513fb22239fe43383d608438cab44ceb Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Wed, 3 Jun 2026 21:13:02 +0000 Subject: [PATCH 09/18] shift routing to unified routing function --- aiter/ops/triton/moe/moe_routing/routing.py | 159 ++++++++---------- op_tests/triton_tests/moe/test_moe_routing.py | 46 +++-- 2 files changed, 97 insertions(+), 108 deletions(-) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index de1c577b83..b0979ea59c 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -257,67 +257,12 @@ def pad(x): # -------------------------- -def routing(logits, n_expts_act, sm_first=False): - HIST_BLOCK_M = 32 - - from .topk import topk - - if sm_first: - logits = torch.softmax(logits, dim=-1) - expt_scal, expt_indx, bitmatrix = topk( - logits, - n_expts_act, - apply_softmax=not sm_first, - 
HIST_BLOCK_M=HIST_BLOCK_M, - ) - - num_tokens, n_expts_tot = logits.shape - m = num_tokens * n_expts_act - tokens_per_expt = max(1, m // n_expts_tot) - block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) - if num_tokens <= 16: - HIST_BLOCK_M = triton.next_power_of_2(num_tokens) - ( - hist, - topk_indx, - gate_indx, - gate_scal, - token_offs_raw, - token_offs_pad, - block_pid_map, - ) = sort_tokens_fused( - expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M - ) - else: - ( - hist, - topk_indx, - gate_indx, - gate_scal, - token_offs_raw, - token_offs_pad, - block_pid_map, - ) = sort_tokens( - expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M - ) - expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) - - # pack the matmul data structure - gather_indx = topk_indx - scatter_indx = gate_indx - return ( - RoutingData(block_m, gate_scal, hist, n_expts_tot, n_expts_act, expt_data), - gather_indx, - scatter_indx, - ) - - -def routing_ds( +def routing( logits: torch.Tensor, n_expts_act: int, - block_m: int, *, - score_mode: str = "sqrtsoftplus", + score_mode: str | None = None, + sm_first: bool = False, bias: torch.Tensor | None = None, renorm: bool = True, routed_scaling_factor: float = 1.0, @@ -328,45 +273,78 @@ def routing_ds( num_fused_shared_experts: int = 0, shared_experts_score: float = 1.0, ): - """All-Triton routing for the a8w4 path: fused V4 routing math + sort. - - One-shot pipeline: - 1. aiter `_topk` (extended): pre-transform (sqrtsoftplus) + bias + topk - + bitmatrix + renorm + scale — single Triton kernel. - 2. aiter `sort_tokens` (or `sort_tokens_fused` for tiny M): sort tokens by - expert and produce ExptData specialized for the given ``block_m``. 
- - When ``use_grouped_topk=True``, step 1 is replaced by ATOM's single-fused - Triton ``grouped_topk`` kernel - (``atom.model_ops.grouped_topk.grouped_topk``) — DeepSeek-V2/V3-style - hierarchical routing (pick ``topk_group`` groups out of - ``num_expert_group``, then top-``n_expts_act`` experts within those - groups). Same return contract as ``topk`` (y_vals, y_indx, Bitmatrix), so - ``sort_tokens`` / ``sort_tokens_fused`` consume it unchanged. - - Returns (RoutingData, gather_indx, scatter_indx) where gather_indx and - scatter_indx are raw int32 tensors (no GatherIndx/ScatterIndx wrappers) — - consumed directly by ``moe_gemm_a8w4``. - - No multi-block_m dict, no triton_kernels wrapper, no Python bridge step. + """Routing entry point. ``score_mode`` selects the path: + + * ``score_mode is None`` (default) -> the plain flat top-k routing process: + flat top-k with softmax. ``sm_first`` controls whether softmax is applied + to the logits before the top-k (``True``) or inside the top-k + (``False``). The fused-V4-only arguments (``bias``, ``use_grouped_topk``, + ``num_fused_shared_experts`` ...) are ignored on this path. + * ``score_mode is not None`` -> the fused V4 (DeepSeek) routing process: + fused score transform + (optionally grouped) top-k + fused shared + experts. ``sm_first`` is ignored on this path. + + ``block_m`` is not supplied by the caller: it is derived internally from the + raw ``logits`` shape and the originally requested ``n_expts_act`` (before + any shared-expert widening). + + Returns ``(RoutingData, gather_indx, scatter_indx)``. """ - n_tokens, n_routed = logits.shape + num_tokens, n_routed = logits.shape - # Fused shared experts are appended (always-on) to every token by the - # grouped-topk kernel, occupying expert ids [n_routed, n_routed + n_shared). - # They widen both the per-token selection (n_expts_act) and the total - # expert count used for the sort / histogram. 
+ # block_m heuristic from the raw logits shape and the originally requested + # n_expts_act. + m = num_tokens * n_expts_act + tokens_per_expt = max(1, m // n_routed) + block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + + # ------------------------------------------------------------------ + # flat top-k path: plain top-k + softmax (score_mode is None) + # ------------------------------------------------------------------ + if score_mode is None: + from .topk import topk + + HIST_BLOCK_M = 32 + if sm_first: + logits = torch.softmax(logits, dim=-1) + expt_scal, expt_indx, bitmatrix = topk( + logits, + n_expts_act, + apply_softmax=not sm_first, + HIST_BLOCK_M=HIST_BLOCK_M, + ) + n_expts_tot = n_routed + if num_tokens <= 16: + HIST_BLOCK_M = triton.next_power_of_2(num_tokens) + sort_fn = sort_tokens_fused + else: + sort_fn = sort_tokens + ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) = sort_fn(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M) + expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + return ( + RoutingData(block_m, gate_scal, hist, n_expts_tot, n_expts_act, expt_data), + topk_indx, + gate_indx, + ) + + # ------------------------------------------------------------------ + # fused V4 path: fused routing math + sort (score_mode given) + # ------------------------------------------------------------------ n_shared = num_fused_shared_experts n_expts_tot = n_routed + n_shared - # Step 1: per-token expert selection. Either flat top-k (existing aiter - # _topk kernel) or grouped top-k (ATOM's _grouped_topk kernel) — both - # return (y_vals, y_indx, Bitmatrix) with the same downstream contract. if use_grouped_topk and num_expert_group != 1: assert ( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" - # Lazy import: ATOM-side kernel; avoids hard aiter→atom import order. 
from aiter.ops.triton.moe.moe_routing.topk import grouped_topk expt_scal, expt_indx, bitmatrix = grouped_topk( @@ -403,9 +381,8 @@ def routing_ds( HIST_BLOCK_M=32, ) - # Step 2: sort tokens by expert and build ExptData for the chosen block_m. - if n_tokens <= 16: - HIST_BLOCK_M = triton.next_power_of_2(max(n_tokens, 1)) + if num_tokens <= 16: + HIST_BLOCK_M = triton.next_power_of_2(max(num_tokens, 1)) sort_fn = sort_tokens_fused else: HIST_BLOCK_M = 32 diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 6abe69a46c..2821047da5 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -1,14 +1,26 @@ import pytest import torch +import triton import torch.nn.functional as F from aiter.ops.triton.moe.moe_routing.routing import ( routing, - routing_ds, routing_a8w4_from_hash, routing_a8w4_from_topk, routing_torch, compute_expt_data_torch, ) + + +def _routing_block_m(n_tokens, n_expts_act, n_expts_tot): + """block_m heuristic used by `routing`. + + Uses the raw logits shape and the originally requested n_expts_act (before + any shared-expert widening), exactly as `routing` does internally. 
+ """ + tokens_per_expt = max(1, (n_tokens * n_expts_act) // n_expts_tot) + return max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) + + from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch from aiter.ops.triton.moe.moe_routing.topk import grouped_topk @@ -125,7 +137,7 @@ def test_routing(n_tokens, n_expts_tot, n_expts_act, sm_first): ref_logits, n_expts_act, sm_first ) tri_routing_data, tri_gather, tri_scatter = routing( - tri_logits, n_expts_act, sm_first + tri_logits, n_expts_act, sm_first=sm_first ) def _assert_indx_equal(ref, tri): @@ -386,7 +398,6 @@ def _check_routing_data_bucket( ("softmax", False, False, 1.0), # identity transform, no renorm ], ) -@pytest.mark.parametrize("block_m", [16, 32]) def test_routing_a8w4( n_tokens, n_expts_tot, @@ -395,7 +406,6 @@ def test_routing_a8w4( has_bias, renorm, routed_scaling_factor, - block_m, ): if get_arch() not in ["gfx950", "gfx1250"]: pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") @@ -409,6 +419,9 @@ def test_routing_a8w4( else None ) + # routing derives block_m internally; mirror that here for the ref. 
+ block_m = _routing_block_m(n_tokens, n_expts_act, n_expts_tot) + ref_pack = routing_a8w4_torch( logits.clone(), n_expts_act, @@ -418,10 +431,9 @@ def test_routing_a8w4( renorm=renorm, routed_scaling_factor=routed_scaling_factor, ) - tri_routing_data, tri_gather, tri_scatter = routing_ds( + tri_routing_data, tri_gather, tri_scatter = routing( logits, n_expts_act, - block_m, score_mode=score_mode, bias=bias, renorm=renorm, @@ -895,9 +907,8 @@ def test_grouped_topk_arbitrary_group( @pytest.mark.parametrize( "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES ) -@pytest.mark.parametrize("block_m", [16, 32]) def test_routing_ds_grouped( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act ): _maybe_skip() device = "cuda" @@ -906,6 +917,9 @@ def test_routing_ds_grouped( bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 + # routing derives block_m internally; mirror that here for the ref. + block_m = _routing_block_m(n_tokens, n_expts_act, n_expts_tot) + # The selection the kernel makes (deterministic for fixed inputs); used as # ground truth for the sort/scatter pipeline check. 
y_vals, y_indx, _ = grouped_topk( @@ -919,10 +933,9 @@ def test_routing_ds_grouped( routed_scaling_factor=scale, ) - tri_routing_data, tri_gather, tri_scatter = routing_ds( + tri_routing_data, tri_gather, tri_scatter = routing( logits, n_expts_act, - block_m, score_mode=score_mode, bias=bias, renorm=renorm, @@ -1012,14 +1025,14 @@ def test_grouped_topk_shared_expert( @pytest.mark.parametrize( "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES ) -@pytest.mark.parametrize("block_m", [16, 32]) @pytest.mark.parametrize("n_shared", [1, 2]) def test_routing_ds_grouped_shared( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, block_m, n_shared + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared ): - """End-to-end routing_ds with fused shared experts: histogram must include - a full shared bucket (n_tokens) per shared expert and the gather/scatter must - form a valid inverse permutation over the widened gate count.""" + """End-to-end routing with fused shared experts: histogram must + include a full shared bucket (n_tokens) per shared expert and the + gather/scatter must form a valid inverse permutation over the widened gate + count.""" _maybe_skip() device = "cuda" torch.manual_seed(2) @@ -1027,10 +1040,9 @@ def test_routing_ds_grouped_shared( bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - rd, gather, scatter = routing_ds( + rd, gather, scatter = routing( logits, n_expts_act, - block_m, score_mode=score_mode, bias=bias, renorm=renorm, From 08eba11a0e5065674c0358fea24101d6764bbcfa Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Wed, 3 Jun 2026 21:26:49 +0000 Subject: [PATCH 10/18] ruff --- op_tests/triton_tests/moe/test_moe_routing.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 
2821047da5..015d07497b 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -9,6 +9,9 @@ routing_torch, compute_expt_data_torch, ) +from aiter.ops.triton.utils._triton.arch_info import get_arch +from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch +from aiter.ops.triton.moe.moe_routing.topk import grouped_topk def _routing_block_m(n_tokens, n_expts_act, n_expts_tot): @@ -21,11 +24,6 @@ def _routing_block_m(n_tokens, n_expts_act, n_expts_tot): return max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) -from aiter.ops.triton.utils._triton.arch_info import get_arch -from aiter.ops.topk import biased_grouped_topk_torch, grouped_topk_torch -from aiter.ops.triton.moe.moe_routing.topk import grouped_topk - - def assert_equal(ref, tri): if isinstance(ref, torch.Tensor): # CI may be failing using this: From 4b63959dfdf4a66b43a754ae4141facaf5be01c5 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 15:35:33 +0000 Subject: [PATCH 11/18] comment fixes --- aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py | 2 +- aiter/ops/triton/moe/moe_routing/topk.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py index 014fdab791..7da43d1e54 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py @@ -481,7 +481,7 @@ def _grouped_topk( tl.static_assert(BLOCK_N % 32 == 0) tl.static_assert( N_EXPTS_PAD == BLOCK_N, - "DeepSeek-class envelope: BLOCK_N must equal N_EXPTS_PAD (single-block).", + "grouped topk BLOCK_N must equal N_EXPTS_PAD (single-block).", ) x_dtype: tl.constexpr = X.dtype.element_ty diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index a5627dd577..48d388f919 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py 
+++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -49,7 +49,7 @@ def grouped_topk( n_rows, n_cols = x.shape assert ( n_cols <= 256 - ), f"DeepSeek-class envelope: n_expts_tot ({n_cols}) must be <= 256" + ), f"grouped_topk n_expts_tot ({n_cols}) only supported <= 256" # Fused shared experts are appended (always-on) AFTER the routed selection; # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). n_shared = num_fused_shared_experts From 986245769a43ef7a82421bf1143ae06c36fab559 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 16:05:00 +0000 Subject: [PATCH 12/18] black --- aiter/ops/triton/moe/moe_routing/topk.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index 48d388f919..8f2097122c 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py +++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -47,9 +47,7 @@ def grouped_topk( """ assert x.dim() == 2 n_rows, n_cols = x.shape - assert ( - n_cols <= 256 - ), f"grouped_topk n_expts_tot ({n_cols}) only supported <= 256" + assert n_cols <= 256, f"grouped_topk n_expts_tot ({n_cols}) only supported <= 256" # Fused shared experts are appended (always-on) AFTER the routed selection; # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). n_shared = num_fused_shared_experts From 33e5bd1e81983854aafd79e308e067e25c864a28 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 17:42:17 +0000 Subject: [PATCH 13/18] verified triton kernel is functionally the same as the one needed in atom. 
small change to import in routing --- aiter/ops/triton/moe/moe_routing/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index b0979ea59c..fd50e91f86 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -12,6 +12,7 @@ _expt_data_only_kernel, ) from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail +from aiter.ops.triton.moe.moe_routing.topk import grouped_topk @dataclass @@ -345,7 +346,6 @@ def routing( assert ( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" - from aiter.ops.triton.moe.moe_routing.topk import grouped_topk expt_scal, expt_indx, bitmatrix = grouped_topk( logits, From 4a00f5c5911d3be4d192adbcc9b7debdadf95a08 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 22:30:49 +0000 Subject: [PATCH 14/18] review adjustments --- aiter/ops/triton/moe/moe_routing/routing.py | 61 +-- .../fusions/test_fused_routing_from_topk.py | 287 ------------ op_tests/triton_tests/moe/test_moe_routing.py | 442 +++++++----------- 3 files changed, 158 insertions(+), 632 deletions(-) delete mode 100644 op_tests/triton_tests/fusions/test_fused_routing_from_topk.py diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index fd50e91f86..590cac034d 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -408,7 +408,7 @@ def routing( return routing_data, topk_indx, gate_indx -def routing_a8w4_from_hash( +def routing_from_hash( router_logits: torch.Tensor, tid2eid: torch.Tensor, input_ids: torch.Tensor, @@ -470,65 +470,6 @@ def routing_a8w4_from_hash( return routing_data, topk_indx, gate_indx -def routing_a8w4_from_topk( - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - n_expts_tot: int, - block_m: int, -): - 
"""Routing for the a8w4 path when topk has been pre-computed externally - (e.g. DeepSeek-V4 hash layers with tid2eid lookup). - - Mirrors ``routing_a8w4`` but skips the score+topk math step. Pipeline: - 1. aiter ``fused_routing_from_topk``: 3-kernel counting-sort over the - supplied ``(topk_weights, topk_ids)``. Allocates only via - ``torch.empty`` — no histogram memset. - 2. aiter ``_expt_data_only_kernel``: standalone stage1+stage2 launch - that materialises ExptData (token_offs_raw, token_offs_pad, - block_pid_map) from the histogram for the chosen ``block_m``. - - Returns ``(RoutingData, gather_indx, scatter_indx)`` where ``gather_indx`` - and ``scatter_indx`` are raw int32 tensors — same contract as - ``routing_a8w4`` — so ``_a8w4_fused_experts`` consumes them unchanged. - """ - - n_tokens, n_expts_act = topk_weights.shape - n_gates = n_tokens * n_expts_act - - hist, topk_indx, gate_indx, gate_scal = fused_routing_from_topk( - topk_weights, topk_ids, n_expts_tot - ) - - token_offs_raw, token_offs_pad, block_pid_map, blocks1a, BLOCK_A, block_m_log2 = ( - _compute_expt_data_internal(n_expts_tot, n_gates, block_m, topk_weights.device) - ) - - _expt_data_only_kernel[(blocks1a,)]( - hist, - n_expts_tot, - token_offs_raw, - token_offs_pad, - block_pid_map, - block_pid_map.shape[0], - n_gates, - block_m_log2, - BLOCK_A, - (hist.shape[0] == BLOCK_A), - num_warps=1, - ) - - expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) - routing_data = RoutingData( - block_m=block_m, - gate_scal=gate_scal, - expt_hist=hist, - n_expts_tot=n_expts_tot, - n_expts_act=n_expts_act, - expt_data=expt_data, - ) - return routing_data, topk_indx, gate_indx - - # -------------------------- # torch reference # -------------------------- diff --git a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py deleted file mode 100644 index 098b889a43..0000000000 --- 
a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py +++ /dev/null @@ -1,287 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. - -# Correctness test for ``aiter.ops.triton.fused_routing_from_topk``. -# -# The fused kernel skips the per-row sort that a stable-argsort reference -# performs, so its ``topk_indx``/``gate_indx`` may differ from the reference -# at *intra-expert* ordering. Equivalence is therefore checked at multiple -# levels: -# -# 1. Inverse-permutation invariant on the fused output. -# 2. Bucket invariant: items at expert-sorted positions for expert ``e`` -# reference (token, slot) pairs whose expert id equals ``e``. -# 3. ``hist`` matches the reference exactly. -# 4. Per-expert ``(token, weight)`` multisets match the reference (and a -# ground-truth bucket built directly from inputs). -import pytest -import torch - -from aiter.ops.triton.fusions.fused_routing_from_topk import fused_routing_from_topk - -DEVICE = "cuda" - - -# --------------------------------------------------------------------------- -# Reference implementation — mirrors the multi-kernel torch chain originally -# used to bridge FusedMoE.select_experts to triton_kernels.matmul_ogs. -# Returns ``(hist, topk_indx, gate_indx, gate_scal)`` for direct comparison -# against the fused kernel. -# --------------------------------------------------------------------------- -def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot, expert_map=None): - """Multi-kernel torch reference for fused_routing_from_topk. - - Per-row sort of ``topk_ids`` followed by a stable global argsort by - expert id, an inverse-permutation argsort, and an integer histogram. - The output indices are bit-exact stable across runs (modulo torch - version), unlike the fused kernel which is non-deterministic at - intra-expert ordering. 
- """ - if expert_map is not None: - local_ids = expert_map[topk_ids.long()] - invalid = local_ids < 0 - topk_weights = topk_weights.masked_fill(invalid, 0.0) - topk_ids = local_ids.masked_fill(invalid, 0).to(torch.int32) - - expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1) - expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long()) - - expt_scal = expt_scal_sorted.reshape(-1).to(topk_weights.dtype) - expt_indx = expt_indx_sorted.reshape(-1).to(torch.int32) - - topk_indx = torch.argsort(expt_indx, stable=True).int() - gate_indx = torch.argsort(topk_indx, stable=True).int() - gate_scal = expt_scal[topk_indx.long()] - - hist = torch.histc(expt_indx.float(), bins=n_expts_tot, max=n_expts_tot - 1).int() - return hist, topk_indx, gate_indx, gate_scal - - -# --------------------------------------------------------------------------- -# Test helpers -# --------------------------------------------------------------------------- -def _make_inputs(n_tokens, n_expts_act, n_expts_tot, dtype, device, seed): - """Random topk-style inputs: distinct expert ids per row + L1-normalized - positive weights (matches FusedMoE.select_experts post-renormalize).""" - g = torch.Generator(device=device).manual_seed(seed) - ids = torch.empty(n_tokens, n_expts_act, dtype=torch.int32, device=device) - for n in range(n_tokens): - ids[n] = torch.randperm(n_expts_tot, generator=g, device=device)[ - :n_expts_act - ].to(torch.int32) - weights = torch.rand(n_tokens, n_expts_act, generator=g, device=device, dtype=dtype) - weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-12) - return ids, weights - - -def _check_routing_invariants( - hist, - topk_indx, - gate_indx, - gate_scal, - topk_ids, - n_expts_tot, - *, - bucket_unsorted_layout, -): - """Sanity invariants that any valid fused-routing output must satisfy. - - ``bucket_unsorted_layout`` enables the bucket invariant against - ``topk_ids.flatten()`` directly. 
That only holds for the fused kernel - (which skips the per-row sort); a stable-argsort reference uses a - per-token-sorted flat layout, so its ``topk_indx`` indexes a different - array. - """ - n_tokens, K = topk_ids.shape - NK = n_tokens * K - device = topk_ids.device - - # 1. Inverse permutation: gate_indx[topk_indx[j]] == j for all j. - iota = torch.arange(NK, dtype=torch.int32, device=device) - inv_check = gate_indx[topk_indx.long()] - assert torch.equal(inv_check, iota), ( - "gate_indx[topk_indx[j]] != j (first mismatch at " - f"{(inv_check != iota).nonzero()[0].item()})" - ) - - # 2. hist is non-negative int32 and sums to NK. - assert hist.dtype == torch.int32, f"hist dtype != int32 (got {hist.dtype})" - assert (hist >= 0).all(), "hist has negative entries" - assert hist.sum().item() == NK, f"hist.sum()={hist.sum().item()} != NK={NK}" - - # 3. gate_scal is finite and same length as topk_indx. - assert gate_scal.numel() == NK - assert torch.isfinite(gate_scal).all(), "gate_scal has non-finite values" - - # 4. Bucket invariant (fused-only): items at expert-sorted positions - # [prefix[e], prefix[e+1]) reference original (token, slot) pairs - # whose expert id equals e in the *unsorted* topk_ids flat layout. 
- if bucket_unsorted_layout: - prefix_end = torch.cumsum(hist, dim=0).cpu().tolist() - flat_ids = topk_ids.reshape(-1).cpu().tolist() - src = topk_indx.cpu().tolist() - start = 0 - for e in range(n_expts_tot): - end = prefix_end[e] - for j in range(start, end): - assert flat_ids[src[j]] == e, ( - f"expert-sorted pos {j}: expected expert {e} " - f"but original_flat={src[j]} has expert " - f"{flat_ids[src[j]]}" - ) - start = end - - -def _ground_truth_buckets(topk_ids, topk_weights): - """Build the (token, weight) multiset per expert directly from the - inputs — independent of any routing implementation.""" - _, K = topk_ids.shape - flat_ids = topk_ids.reshape(-1).cpu().tolist() - flat_w = topk_weights.reshape(-1).float().cpu().tolist() - buckets: dict[int, list] = {} - for i, e in enumerate(flat_ids): - token = i // K - buckets.setdefault(e, []).append((token, flat_w[i])) - for e in buckets: - buckets[e].sort() - return buckets - - -def _per_expert_triples(hist, topk_indx, gate_scal, K): - """Walk the expert-sorted layout and bucket (token, weight) pairs by - expert id, using ``hist`` to determine each bucket's slice.""" - NK = topk_indx.numel() - n_expts_tot = hist.numel() - cum = torch.cumsum(hist, dim=0).cpu().tolist() - - src = topk_indx.cpu().tolist() - scal = gate_scal.float().cpu().tolist() - - buckets: dict[int, list] = {e: [] for e in range(n_expts_tot)} - e = 0 - for j in range(NK): - while e < n_expts_tot and j >= cum[e]: - e += 1 - original_flat = src[j] - token = original_flat // K - buckets[e].append((token, scal[j])) - for e in buckets: - buckets[e].sort() - return buckets - - -def _compare_buckets(ref_buckets, test_buckets, atol=1e-6): - keys = set(ref_buckets) | set(test_buckets) - for e in keys: - rb = ref_buckets.get(e, []) - tb = test_buckets.get(e, []) - assert len(rb) == len( - tb - ), f"expert {e}: bucket size ref={len(rb)} test={len(tb)}" - for (tt_r, w_r), (tt_t, w_t) in zip(rb, tb): - assert tt_r == tt_t, f"expert {e}: token mismatch 
ref={tt_r} test={tt_t}" - assert ( - abs(w_r - w_t) <= atol - ), f"expert {e}: token {tt_r} weight ref={w_r} test={w_t}" - - -# --------------------------------------------------------------------------- -# tests -# --------------------------------------------------------------------------- -@pytest.mark.parametrize( - "n_tokens, n_expts_act, n_expts_tot, n_expts_global", - [ - # V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when - # has_expert_map=False. - (1, 6, 256, 256), - (16, 6, 256, 256), - (64, 6, 256, 256), - (256, 6, 256, 256), - # Generic decode shapes used by other MoE configs. - (1, 8, 384, 384), - (4, 8, 384, 384), - (64, 8, 384, 384), - (256, 8, 384, 384), - # Edge: small E. - (32, 4, 16, 16), - # Boundary: NK at the kernel's MAX_NK = 4096. - (512, 8, 384, 384), - # Expert-parallel shapes: n_expts_global > n_expts_tot, requires map. - (16, 6, 64, 256), - (64, 6, 128, 256), - ], -) -@pytest.mark.parametrize("has_expert_map", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float32]) -def test_fused_routing_from_topk( - n_tokens, n_expts_act, n_expts_tot, n_expts_global, has_expert_map, dtype -): - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - torch.manual_seed(0) - - id_range = n_expts_global if has_expert_map else n_expts_tot - topk_ids, topk_weights = _make_inputs( - n_tokens, n_expts_act, id_range, dtype, DEVICE, seed=0 - ) - - expert_map = None - if has_expert_map: - expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE) - expert_map[: n_expts_tot // 2] = torch.arange( - n_expts_tot // 2, dtype=torch.int32, device=DEVICE - ) - - ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference( - topk_weights, topk_ids, n_expts_tot, expert_map=expert_map - ) - _check_routing_invariants( - ref_hist, - ref_topk_indx, - ref_gate_indx, - ref_gate_scal, - topk_ids, - n_expts_tot, - bucket_unsorted_layout=False, # ref uses per-row-sorted layout - ) - - test_hist, 
test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk( - topk_weights, topk_ids, n_expts_tot, expert_map=expert_map - ) - _check_routing_invariants( - test_hist, - test_topk_indx, - test_gate_indx, - test_gate_scal, - topk_ids, - n_expts_tot, - bucket_unsorted_layout=not has_expert_map, - ) - - # hist must match the reference exactly. - assert torch.equal( - ref_hist, test_hist - ), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" - - if has_expert_map: - # Intra-expert ordering can differ between fused and reference, - # especially in expert-0 bucket where invalid experts are redirected. - # Compare zeroed-weight cardinality instead of elementwise positions. - ref_zero_count = int((ref_gate_scal == 0).sum().item()) - test_zero_count = int((test_gate_scal == 0).sum().item()) - assert ref_zero_count == test_zero_count, ( - f"zero-masked count mismatch: " - f"ref={ref_zero_count}, fused={test_zero_count}" - ) - else: - ground_buckets = _ground_truth_buckets(topk_ids, topk_weights) - ref_buckets = _per_expert_triples( - ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act - ) - _compare_buckets(ground_buckets, ref_buckets) - - # Per-expert (token, weight) multisets match the reference. 
- test_buckets = _per_expert_triples( - test_hist, test_topk_indx, test_gate_scal, n_expts_act - ) - _compare_buckets(ref_buckets, test_buckets) diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 015d07497b..01e7620583 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -4,8 +4,7 @@ import torch.nn.functional as F from aiter.ops.triton.moe.moe_routing.routing import ( routing, - routing_a8w4_from_hash, - routing_a8w4_from_topk, + routing_from_hash, routing_torch, compute_expt_data_torch, ) @@ -163,7 +162,7 @@ def _assert_indx_equal(ref, tri): # -------------------------- -# Reference implementations for routing_a8w4* paths +# Reference implementations for routing with score mode paths # -------------------------- @@ -197,7 +196,7 @@ def _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m): return hist, topk_indx, gate_indx, gate_scal, expt_data -def routing_a8w4_torch( +def routing_score_mode_torch( logits, n_expts_act, block_m, @@ -236,7 +235,7 @@ def routing_a8w4_torch( return _sort_and_build_torch(expt_scal, topk_ids, n_expts_tot, block_m) -def routing_a8w4_from_hash_torch( +def routing_from_hash_torch( router_logits, tid2eid, input_ids, @@ -269,15 +268,6 @@ def routing_a8w4_from_hash_torch( return _sort_and_build_torch(expt_scal, expt_indx, n_expts_tot, block_m) -def routing_a8w4_from_topk_torch(topk_weights, topk_ids, n_expts_tot, block_m): - return _sort_and_build_torch( - topk_weights, - topk_ids.to(torch.int16), - n_expts_tot, - block_m, - ) - - def _check_routing_data(ref_pack, tri_routing_data, tri_gather, tri_scatter): """Strict equality check: works when the triton sort and stable argsort agree on intra-bucket order (the sort_tokens / sort_tokens_fused path).""" @@ -325,7 +315,7 @@ def _check_routing_data_bucket( n_expts_tot = ref_hist.numel() # Inverse permutation invariant: gate_indx[topk_indx[j]] == j. 
- # Cast scatter to int64 first: the grouped routing_ds path returns uint16 + # Cast scatter to int64 first: the grouped routing_score_mode path returns uint16 # indices, which CUDA cannot advanced-index. iota = torch.arange(n_gates, dtype=torch.int64, device=tri_gather.device) assert torch.equal( @@ -373,7 +363,7 @@ def _check_routing_data_bucket( # -------------------------- -# routing_a8w4 +# routing score mode # -------------------------- @@ -396,7 +386,7 @@ def _check_routing_data_bucket( ("softmax", False, False, 1.0), # identity transform, no renorm ], ) -def test_routing_a8w4( +def test_routing_score_mode( n_tokens, n_expts_tot, n_expts_act, @@ -420,7 +410,7 @@ def test_routing_a8w4( # routing derives block_m internally; mirror that here for the ref. block_m = _routing_block_m(n_tokens, n_expts_act, n_expts_tot) - ref_pack = routing_a8w4_torch( + ref_pack = routing_score_mode_torch( logits.clone(), n_expts_act, block_m, @@ -445,7 +435,7 @@ def test_routing_a8w4( # -------------------------- -# routing_a8w4_from_hash +# routing_from_hash # -------------------------- @@ -466,7 +456,7 @@ def test_routing_a8w4( ], ) @pytest.mark.parametrize("block_m", [16, 32]) -def test_routing_a8w4_from_hash( +def test_routing_from_hash( n_tokens, n_expts_tot, n_expts_act, @@ -497,7 +487,7 @@ def test_routing_a8w4_from_hash( 0, vocab_size, (n_tokens,), dtype=torch.int32, device=device ) - ref_pack = routing_a8w4_from_hash_torch( + ref_pack = routing_from_hash_torch( router_logits.clone(), tid2eid, input_ids, @@ -507,7 +497,7 @@ def test_routing_a8w4_from_hash( renorm=renorm, routed_scaling_factor=routed_scaling_factor, ) - tri_routing_data, tri_gather, tri_scatter = routing_a8w4_from_hash( + tri_routing_data, tri_gather, tri_scatter = routing_from_hash( router_logits, tid2eid, input_ids, @@ -524,72 +514,6 @@ def test_routing_a8w4_from_hash( assert tri_routing_data.block_m == block_m -# -------------------------- -# routing_a8w4_from_topk -# -------------------------- - - -# 
fused_routing_from_topk requires n_tokens * n_expts_act <= 4096. -@pytest.mark.parametrize( - "n_tokens, n_expts_tot, n_expts_act", - [ - (8, 128, 4), - (64, 128, 4), - (256, 128, 4), - (256, 256, 8), - (512, 128, 8), - ], -) -@pytest.mark.parametrize("block_m", [16, 32]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) -def test_routing_a8w4_from_topk( - n_tokens, - n_expts_tot, - n_expts_act, - block_m, - dtype, -): - if get_arch() not in ["gfx950", "gfx1250"]: - pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") - - device = "cuda" - torch.manual_seed(2) - topk_weights = torch.randn(n_tokens, n_expts_act, dtype=dtype, device=device).abs() - # Per-row unique expert ids (the natural V4 case). - topk_ids = torch.stack( - [ - torch.randperm(n_expts_tot, device=device)[:n_expts_act] - for _ in range(n_tokens) - ], - dim=0, - ).to(torch.int32) - - ref_pack = routing_a8w4_from_topk_torch( - topk_weights.clone(), - topk_ids.clone(), - n_expts_tot, - block_m, - ) - tri_routing_data, tri_gather, tri_scatter = routing_a8w4_from_topk( - topk_weights, - topk_ids, - n_expts_tot, - block_m, - ) - - _check_routing_data_bucket( - ref_pack, - tri_routing_data, - tri_gather, - tri_scatter, - topk_weights, - topk_ids, - ) - assert tri_routing_data.n_expts_tot == n_expts_tot - assert tri_routing_data.n_expts_act == n_expts_act - assert tri_routing_data.block_m == block_m - - # ========================================================================== # grouped-top-k routing (aiter.ops.triton.moe.moe_routing.topk.grouped_topk) # Moved from test_grouped_topk.py. Reuses the shared helpers above @@ -784,28 +708,87 @@ def _maybe_skip(): # -------------------------------------------------------------------------- # 1. direct kernel test: (y_vals, y_indx, bitmatrix) +# +# Unified across contiguous/arbitrary expert->group layouts, score modes, and +# 0/1/2 fused always-on shared experts. 
The curated case list reproduces the +# original three tests' coverage exactly (contiguous x SCORE_COMBOS x no shared; +# arbitrary x sqrtsoftplus; contiguous x sqrtsoftplus x shared 1/2). # -------------------------------------------------------------------------- +# sqrtsoftplus + bias + renorm + scale=2.5: the fixed combo the arbitrary-group +# and shared-expert variants exercise. +SQ_COMBO = ("sqrtsoftplus", True, True, 2.5) + + +def _make_shuffled_expert_group(n_expts_tot, num_expert_group, device): + """Equal-size groups with a shuffled (non-contiguous) expert->group table.""" + g_size = n_expts_tot // num_expert_group + perm = torch.randperm(n_expts_tot, device=device) + expert_group = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + for g in range(num_expert_group): + expert_group[perm[g * g_size : (g + 1) * g_size]] = g + return expert_group + + +def _grouped_topk_kernel_cases(): + cases = [] + # (1) contiguous groups, all score combos, no shared experts. + for nt in GROUPED_N_TOKENS: + for shape in GROUP_SHAPES: + for sc in SCORE_COMBOS: + cases.append( + pytest.param( + nt, + shape, + sc, + "contiguous", + 0, + id=f"contig-nt{nt}-e{shape[0]}-{sc[0]}-s0", + ) + ) + # (2) arbitrary (non-contiguous) expert->group table, fixed sqrtsoftplus. + for nt in [8, 64, 1024]: + for shape in GROUP_SHAPES: + cases.append( + pytest.param( + nt, + shape, + SQ_COMBO, + "arbitrary", + 0, + id=f"arb-nt{nt}-e{shape[0]}-s0", + ) + ) + # (3) contiguous groups, fixed sqrtsoftplus, 1/2 fused shared experts. 
+ for nt in [8, 64, 1024]: + for shape in GROUP_SHAPES: + for ns in [1, 2]: + cases.append( + pytest.param( + nt, + shape, + SQ_COMBO, + "contiguous", + ns, + id=f"contig-nt{nt}-e{shape[0]}-s{ns}", + ) + ) + return cases + -@pytest.mark.parametrize("n_tokens", GROUPED_N_TOKENS) @pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES + "n_tokens, shape, score_combo, group_mode, n_shared", + _grouped_topk_kernel_cases(), ) -@pytest.mark.parametrize("score_mode, has_bias, renorm, scale", SCORE_COMBOS) -def test_grouped_topk_kernel( - n_tokens, - n_expts_tot, - num_expert_group, - topk_group, - n_expts_act, - score_mode, - has_bias, - renorm, - scale, -): +def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode, n_shared): + """Direct grouped_topk kernel test: routed selection + bitmatrix vs torch + reference, parametrized over expert->group layout, score mode, and fused + shared experts.""" _maybe_skip() + n_expts_tot, num_expert_group, topk_group, n_expts_act = shape + score_mode, has_bias, renorm, scale = score_combo device = "cuda" - torch.manual_seed(2) + torch.manual_seed(7 if group_mode == "arbitrary" else 2) logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) bias = ( torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 @@ -813,91 +796,75 @@ def test_grouped_topk_kernel( else None ) - ref_w, ref_ids = _ref_contiguous( - logits.clone(), - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) + if group_mode == "arbitrary": + expert_group = _make_shuffled_expert_group( + n_expts_tot, num_expert_group, device + ) + ref_w, ref_ids = _ref_arbitrary_grouped( + logits.clone(), + expert_group, + n_expts_act, + num_expert_group, + topk_group, + score_mode, + bias, + renorm, + scale, + ) + else: + expert_group = None + ref_w, ref_ids = _ref_contiguous( + logits.clone(), + n_expts_act, + num_expert_group, + topk_group, + score_mode, + 
bias, + renorm, + scale, + ) + y_vals, y_indx, bitmatrix = grouped_topk( logits, n_expts_act, num_expert_group=num_expert_group, topk_group=topk_group, + expert_group=expert_group, score_mode=score_mode, bias=bias, renorm=renorm, routed_scaling_factor=scale, + num_fused_shared_experts=n_shared, + shared_experts_score=1.0, ) - assert y_vals.shape == (n_tokens, n_expts_act) - assert y_indx.shape == (n_tokens, n_expts_act) + assert y_vals.shape == (n_tokens, n_expts_act + n_shared) + assert y_indx.shape == (n_tokens, n_expts_act + n_shared) assert y_indx.dtype == torch.int16 assert y_vals.dtype == logits.dtype - _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) - - -# -------------------------------------------------------------------------- -# 2. arbitrary (non-contiguous) expert->group mapping -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -def test_grouped_topk_arbitrary_group( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act -): - _maybe_skip() - device = "cuda" - torch.manual_seed(7) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - - # Equal-size groups but a shuffled (non-contiguous) expert->group table. 
- g_size = n_expts_tot // num_expert_group - perm = torch.randperm(n_expts_tot, device=device) - expert_group = torch.empty(n_expts_tot, dtype=torch.int32, device=device) - for g in range(num_expert_group): - expert_group[perm[g * g_size : (g + 1) * g_size]] = g - - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - ref_w, ref_ids = _ref_arbitrary_grouped( - logits.clone(), - expert_group, - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) - y_vals, y_indx, bitmatrix = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - expert_group=expert_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, + # Routed slots (first n_expts_act) must match the reference selection. + _assert_selection_matches( + ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] ) - _assert_selection_matches(ref_ids, ref_w, y_indx, y_vals) - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) + # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. + for i in range(n_shared): + ids_i = y_indx[:, n_expts_act + i].cpu().long() + w_i = y_vals[:, n_expts_act + i].float().cpu() + assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" + assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" + + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) # -------------------------------------------------------------------------- -# 3. end-to-end routing_ds(use_grouped_topk=True) +# 3. end-to-end routing_score_mode(use_grouped_topk=True) +# +# Unified with the former test_routing_score_mode_grouped_shared: n_shared in {0,1,2}. +# grouped_topk (with the same n_shared) is the deterministic ground truth, and +# _check_routing_data_bucket validates hist / ExptData / inverse-permutation / +# per-expert (token, weight) multisets over the (widened) gate count. 
# -------------------------------------------------------------------------- @@ -905,9 +872,15 @@ def test_grouped_topk_arbitrary_group( @pytest.mark.parametrize( "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES ) -def test_routing_ds_grouped( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act +@pytest.mark.parametrize("n_shared", [0, 1, 2]) +def test_routing_score_mode_grouped( + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared ): + """End-to-end routing(use_grouped_topk=True) with 0/1/2 fused always-on + shared experts. The routed selection must match the grouped_topk kernel, the + histogram must include a full shared bucket (n_tokens) per shared expert, and + gather/scatter must form a valid inverse permutation over the widened gate + count.""" _maybe_skip() device = "cuda" torch.manual_seed(2) @@ -915,11 +888,12 @@ def test_routing_ds_grouped( bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - # routing derives block_m internally; mirror that here for the ref. + # routing derives block_m from the raw shape + pre-widening n_expts_act. block_m = _routing_block_m(n_tokens, n_expts_act, n_expts_tot) - # The selection the kernel makes (deterministic for fixed inputs); used as - # ground truth for the sort/scatter pipeline check. + # The selection the kernel makes (deterministic for fixed inputs), including + # the appended shared experts; used as ground truth for the sort/scatter + # pipeline check. 
y_vals, y_indx, _ = grouped_topk( logits, n_expts_act, @@ -929,6 +903,8 @@ def test_routing_ds_grouped( bias=bias, renorm=renorm, routed_scaling_factor=scale, + num_fused_shared_experts=n_shared, + shared_experts_score=1.0, ) tri_routing_data, tri_gather, tri_scatter = routing( @@ -941,127 +917,23 @@ def test_routing_ds_grouped( use_grouped_topk=True, num_expert_group=num_expert_group, topk_group=topk_group, + num_fused_shared_experts=n_shared, ) + n_total = n_expts_tot + n_shared ref_pack = _sort_and_build_torch( - y_vals.float(), y_indx.to(torch.int32), n_expts_tot, block_m + y_vals.float(), y_indx.to(torch.int32), n_total, block_m ) _check_routing_data_bucket( ref_pack, tri_routing_data, tri_gather, tri_scatter, y_vals.float(), y_indx ) - assert tri_routing_data.n_expts_tot == n_expts_tot - assert tri_routing_data.n_expts_act == n_expts_act + assert tri_routing_data.n_expts_tot == n_total + assert tri_routing_data.n_expts_act == n_expts_act + n_shared assert tri_routing_data.block_m == block_m - -# -------------------------------------------------------------------------- -# 4. fused shared experts (DeepSeek-R1/V3 always-on shared expert) -# -------------------------------------------------------------------------- - - -@pytest.mark.parametrize("n_tokens", [8, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("n_shared", [1, 2]) -def test_grouped_topk_shared_expert( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared -): - """The kernel appends `n_shared` always-on shared experts (id n_expts_tot+i, - weight 1.0) AFTER the routed renorm. 
The routed portion must still match the - reference, and the shared columns + bitmatrix must reflect the append.""" - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - - ref_w, ref_ids = _ref_contiguous( - logits.clone(), - n_expts_act, - num_expert_group, - topk_group, - score_mode, - bias, - renorm, - scale, - ) - y_vals, y_indx, bitmatrix = grouped_topk( - logits, - n_expts_act, - num_expert_group=num_expert_group, - topk_group=topk_group, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - num_fused_shared_experts=n_shared, - shared_experts_score=1.0, - ) - - assert y_vals.shape == (n_tokens, n_expts_act + n_shared) - assert y_indx.shape == (n_tokens, n_expts_act + n_shared) - - # Routed slots (first n_expts_act) must match the reference selection. - _assert_selection_matches( - ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] - ) - - # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. + # Each fused shared expert is an always-on bucket: every token routed once. for i in range(n_shared): - ids_i = y_indx[:, n_expts_act + i].cpu().long() - w_i = y_vals[:, n_expts_act + i].float().cpu() - assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" - assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" - - # Bitmatrix must contain routed + shared selections over the widened width. 
- _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) - - -@pytest.mark.parametrize("n_tokens", [8, 16, 64, 1024]) -@pytest.mark.parametrize( - "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES -) -@pytest.mark.parametrize("n_shared", [1, 2]) -def test_routing_ds_grouped_shared( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared -): - """End-to-end routing with fused shared experts: histogram must - include a full shared bucket (n_tokens) per shared expert and the - gather/scatter must form a valid inverse permutation over the widened gate - count.""" - _maybe_skip() - device = "cuda" - torch.manual_seed(2) - logits = init_data(n_tokens, n_expts_tot, device=device, dtype=torch.float32) - bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 - score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - - rd, gather, scatter = routing( - logits, - n_expts_act, - score_mode=score_mode, - bias=bias, - renorm=renorm, - routed_scaling_factor=scale, - use_grouped_topk=True, - num_expert_group=num_expert_group, - topk_group=topk_group, - num_fused_shared_experts=n_shared, - ) - - assert rd.n_expts_tot == n_expts_tot + n_shared - assert rd.n_expts_act == n_expts_act + n_shared - - # Every token is routed to each shared expert exactly once. 
- for i in range(n_shared): - assert rd.expt_hist[n_expts_tot + i].item() == n_tokens - assert rd.expt_hist.sum().item() == n_tokens * (n_expts_act + n_shared) - - n_gates = n_tokens * (n_expts_act + n_shared) - iota = torch.arange(n_gates, dtype=torch.int32, device=gather.device) - assert torch.equal(scatter.long()[gather.long()], iota), "scatter[gather[j]] != j" + assert tri_routing_data.expt_hist[n_expts_tot + i].item() == n_tokens def bench_routing(): From 7d3011a0b53f7595b9fdce965262901326ce535d Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 22:59:29 +0000 Subject: [PATCH 15/18] ruff --- aiter/ops/triton/moe/moe_routing/routing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index 590cac034d..e7593099a6 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -8,9 +8,6 @@ from aiter.ops.triton.fusions.fused_routing_from_topk import ( fused_routing_from_topk, ) -from aiter.ops.triton._triton_kernels.moe.moe_routing.expt_data import ( - _expt_data_only_kernel, -) from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail from aiter.ops.triton.moe.moe_routing.topk import grouped_topk From 682213efa4f547beb407bc6921f06d997f4b5a25 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Thu, 4 Jun 2026 23:01:22 +0000 Subject: [PATCH 16/18] ruff --- aiter/ops/triton/moe/moe_routing/routing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index e7593099a6..f0d3a332bb 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -5,9 +5,6 @@ _combined_routing, _combined_routing_fused, ) -from aiter.ops.triton.fusions.fused_routing_from_topk import ( - fused_routing_from_topk, -) from aiter.ops.triton.utils._triton.arch_info import is_tdm_avail from 
aiter.ops.triton.moe.moe_routing.topk import grouped_topk From 6289847d3df68606858dd675cee1d35c4f21a159 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Fri, 5 Jun 2026 14:19:14 +0000 Subject: [PATCH 17/18] small change for routing without fused shared experts in topk score_mode branch --- aiter/ops/triton/moe/moe_routing/routing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index f0d3a332bb..ef086648ff 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -331,16 +331,19 @@ def routing( ) # ------------------------------------------------------------------ - # fused V4 path: fused routing math + sort (score_mode given) + # fused path: fused routing math + sort (score_mode given) # ------------------------------------------------------------------ n_shared = num_fused_shared_experts - n_expts_tot = n_routed + n_shared + n_expts_tot = n_routed if use_grouped_topk and num_expert_group != 1: assert ( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" + # grouped topk is fused to use num fused shared experts as an expert to route into. 
+ n_expts_tot += n_shared + expt_scal, expt_indx, bitmatrix = grouped_topk( logits, n_expts_act, From 9f83e5a681d1a1fcab682a899379a8423df27d45 Mon Sep 17 00:00:00 2001 From: Amelia Moore Date: Fri, 5 Jun 2026 18:59:35 +0000 Subject: [PATCH 18/18] change to routing to not support fused shared experts --- .../_triton_kernels/moe/moe_routing/topk.py | 25 +----- aiter/ops/triton/moe/moe_routing/routing.py | 30 ++----- aiter/ops/triton/moe/moe_routing/topk.py | 24 +----- op_tests/triton_tests/moe/test_moe_routing.py | 83 +++++-------------- 4 files changed, 37 insertions(+), 125 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py index 7da43d1e54..69d4c5b833 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py @@ -460,8 +460,6 @@ def _grouped_topk( HAS_BIAS: tl.constexpr = False, APPLY_RENORM: tl.constexpr = False, ROUTED_SCALING: tl.constexpr = 1.0, - N_SHARED: tl.constexpr = 0, - SHARED_SCORE: tl.constexpr = 1.0, ): pid = tl.program_id(0) @@ -596,14 +594,12 @@ def _grouped_topk( y_weights = tl.sum(tl.where(pos_eq, scores_3d, 0.0), axis=2) # [BLOCK_M, K_PAD] # Routed-slot mask: the first N_EXPTS_ACT slots hold the grouped-topk - # selection (shared experts, if any, occupy the next N_SHARED slots and - # must be excluded from the routed renorm denominator). + # selection; the remaining padded slots are masked off. k_arange = tl.arange(0, N_EXPTS_ACT_PAD) routed_mask = k_arange[None, :] < N_EXPTS_ACT - # -- 9. Renorm + scale over the ROUTED slots only (mirrors _topk's - # APPLY_RENORM / ROUTED_SCALING and the noaux_tc semantics where the - # always-on shared expert is appended unscaled after renorm). + # -- 9. Renorm + scale over the ROUTED slots (mirrors _topk's + # APPLY_RENORM / ROUTED_SCALING). 
if APPLY_RENORM: y_f = tl.where(routed_mask, y_weights, 0.0) s = tl.sum(y_f, axis=1, keep_dims=True) @@ -611,20 +607,7 @@ def _grouped_topk( elif ROUTED_SCALING != 1.0: y_weights = y_weights * ROUTED_SCALING - # -- 9b. Append fused shared expert(s): always-on, fixed id n_expts_tot+i - # and fixed weight SHARED_SCORE (matches init_aiter_topK_meta_data / - # rocm_aiter_grouped_topk). Placed AFTER renorm so the shared weight - # is not folded into the routed normalization. - if N_SHARED > 0: - shared_slot = (k_arange[None, :] >= N_EXPTS_ACT) & ( - k_arange[None, :] < N_EXPTS_ACT + N_SHARED - ) - shared_idx = (n_expts_tot + k_arange - N_EXPTS_ACT)[None, :].to(tl.int32) - y_indices = tl.where(shared_slot, shared_idx, y_indices) - y_weights = tl.where(shared_slot, SHARED_SCORE, y_weights) - real_mask = k_arange[None, :] < (N_EXPTS_ACT + N_SHARED) - else: - real_mask = routed_mask + real_mask = routed_mask y_values_out = y_weights.to(x_dtype) diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py index ef086648ff..a8b30860d0 100644 --- a/aiter/ops/triton/moe/moe_routing/routing.py +++ b/aiter/ops/triton/moe/moe_routing/routing.py @@ -265,32 +265,29 @@ def routing( num_expert_group: int | None = None, topk_group: int | None = None, expert_group: torch.Tensor | None = None, - num_fused_shared_experts: int = 0, - shared_experts_score: float = 1.0, ): """Routing entry point. ``score_mode`` selects the path: * ``score_mode is None`` (default) -> the plain flat top-k routing process: flat top-k with softmax. ``sm_first`` controls whether softmax is applied to the logits before the top-k (``True``) or inside the top-k - (``False``). The fused-V4-only arguments (``bias``, ``use_grouped_topk``, - ``num_fused_shared_experts`` ...) are ignored on this path. + (``False``). The fused-V4-only arguments (``bias``, ``use_grouped_topk`` + ...) are ignored on this path. 
* ``score_mode is not None`` -> the fused V4 (DeepSeek) routing process: - fused score transform + (optionally grouped) top-k + fused shared - experts. ``sm_first`` is ignored on this path. + fused score transform + (optionally grouped) top-k. ``sm_first`` is + ignored on this path. ``block_m`` is not supplied by the caller: it is derived internally from the - raw ``logits`` shape and the originally requested ``n_expts_act`` (before - any shared-expert widening). + raw ``logits`` shape and the originally requested ``n_expts_act``. Returns ``(RoutingData, gather_indx, scatter_indx)``. """ - num_tokens, n_routed = logits.shape + num_tokens, n_expts_tot = logits.shape # block_m heuristic from the raw logits shape and the originally requested # n_expts_act. m = num_tokens * n_expts_act - tokens_per_expt = max(1, m // n_routed) + tokens_per_expt = max(1, m // n_expts_tot) block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) # ------------------------------------------------------------------ @@ -308,7 +305,6 @@ def routing( apply_softmax=not sm_first, HIST_BLOCK_M=HIST_BLOCK_M, ) - n_expts_tot = n_routed if num_tokens <= 16: HIST_BLOCK_M = triton.next_power_of_2(num_tokens) sort_fn = sort_tokens_fused @@ -333,17 +329,12 @@ def routing( # ------------------------------------------------------------------ # fused path: fused routing math + sort (score_mode given) # ------------------------------------------------------------------ - n_shared = num_fused_shared_experts - n_expts_tot = n_routed if use_grouped_topk and num_expert_group != 1: assert ( num_expert_group is not None and topk_group is not None ), "use_grouped_topk requires num_expert_group and topk_group" - # grouped topk is fused to use num fused shared experts as an expert to route into. 
- n_expts_tot += n_shared - expt_scal, expt_indx, bitmatrix = grouped_topk( logits, n_expts_act, @@ -355,16 +346,9 @@ def routing( bias=bias, renorm=renorm, routed_scaling_factor=routed_scaling_factor, - num_fused_shared_experts=n_shared, - shared_experts_score=shared_experts_score, HIST_BLOCK_M=32, ) - # Routed top-k + appended shared experts per token. - n_expts_act = n_expts_act + n_shared else: - assert ( - n_shared == 0 - ), "fused shared experts are only supported on the grouped-topk path" from .topk import topk expt_scal, expt_indx, bitmatrix = topk( diff --git a/aiter/ops/triton/moe/moe_routing/topk.py b/aiter/ops/triton/moe/moe_routing/topk.py index 8f2097122c..82cb489299 100644 --- a/aiter/ops/triton/moe/moe_routing/topk.py +++ b/aiter/ops/triton/moe/moe_routing/topk.py @@ -21,24 +21,14 @@ def grouped_topk( bias: torch.Tensor | None = None, renorm: bool = False, routed_scaling_factor: float = 1.0, - num_fused_shared_experts: int = 0, - shared_experts_score: float = 1.0, ): """Triton grouped top-k expert selection. See module docstring. Returns ``(y_vals, y_indx, bitmatrix)`` matching the contract of ``aiter.ops.triton.moe.moe_routing.topk.topk``: - - y_vals: ``(n_rows, k + num_fused_shared_experts)`` in ``x.dtype``. - - y_indx: ``(n_rows, k + num_fused_shared_experts)`` ``int16``. - - When ``num_fused_shared_experts > 0`` the routed top-k selection occupies - the first ``k`` columns and the always-on shared expert(s) occupy the next - ``num_fused_shared_experts`` columns — expert id ``n_cols + i``, weight - ``shared_experts_score`` (appended after the routed renorm, mirroring - ``init_aiter_topK_meta_data`` / ``rocm_aiter_grouped_topk``). The bitmatrix - is widened to ``n_cols + num_fused_shared_experts`` columns so ``sort_tokens`` - counts the shared bucket. + - y_vals: ``(n_rows, k)`` in ``x.dtype``. + - y_indx: ``(n_rows, k)`` ``int16``. 
- bitmatrix: real :class:`Bitmatrix`; same uint32 ``(n_cols_words, n_rows_pad32).T`` storage / scratchpad layout the @@ -48,12 +38,8 @@ def grouped_topk( assert x.dim() == 2 n_rows, n_cols = x.shape assert n_cols <= 256, f"grouped_topk n_expts_tot ({n_cols}) only supported <= 256" - # Fused shared experts are appended (always-on) AFTER the routed selection; - # they occupy expert ids [n_cols, n_cols + num_fused_shared_experts). - n_shared = num_fused_shared_experts - assert n_shared >= 0 - n_total = n_cols + n_shared # experts incl. shared (bitmatrix width) - k_out = k + n_shared # output width (routed top-k + shared) + n_total = n_cols # experts (bitmatrix width) + k_out = k # output width (routed top-k) assert num_expert_group > 1 assert ( num_expert_group <= 16 @@ -165,8 +151,6 @@ def grouped_topk( HAS_BIAS=has_bias, APPLY_RENORM=renorm, ROUTED_SCALING=routed_scaling_factor, - N_SHARED=n_shared, - SHARED_SCORE=shared_experts_score, num_warps=4, ) diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 01e7620583..68b05c7ec9 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -732,7 +732,7 @@ def _make_shuffled_expert_group(n_expts_tot, num_expert_group, device): def _grouped_topk_kernel_cases(): cases = [] - # (1) contiguous groups, all score combos, no shared experts. + # (1) contiguous groups, all score combos. for nt in GROUPED_N_TOKENS: for shape in GROUP_SHAPES: for sc in SCORE_COMBOS: @@ -742,8 +742,7 @@ def _grouped_topk_kernel_cases(): shape, sc, "contiguous", - 0, - id=f"contig-nt{nt}-e{shape[0]}-{sc[0]}-s0", + id=f"contig-nt{nt}-e{shape[0]}-{sc[0]}", ) ) # (2) arbitrary (non-contiguous) expert->group table, fixed sqrtsoftplus. 
@@ -755,35 +754,19 @@ def _grouped_topk_kernel_cases(): shape, SQ_COMBO, "arbitrary", - 0, - id=f"arb-nt{nt}-e{shape[0]}-s0", + id=f"arb-nt{nt}-e{shape[0]}", ) ) - # (3) contiguous groups, fixed sqrtsoftplus, 1/2 fused shared experts. - for nt in [8, 64, 1024]: - for shape in GROUP_SHAPES: - for ns in [1, 2]: - cases.append( - pytest.param( - nt, - shape, - SQ_COMBO, - "contiguous", - ns, - id=f"contig-nt{nt}-e{shape[0]}-s{ns}", - ) - ) return cases @pytest.mark.parametrize( - "n_tokens, shape, score_combo, group_mode, n_shared", + "n_tokens, shape, score_combo, group_mode", _grouped_topk_kernel_cases(), ) -def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode, n_shared): +def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode): """Direct grouped_topk kernel test: routed selection + bitmatrix vs torch - reference, parametrized over expert->group layout, score mode, and fused - shared experts.""" + reference, parametrized over expert->group layout and score mode.""" _maybe_skip() n_expts_tot, num_expert_group, topk_group, n_expts_act = shape score_mode, has_bias, renorm, scale = score_combo @@ -834,12 +817,10 @@ def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode, n_shared) bias=bias, renorm=renorm, routed_scaling_factor=scale, - num_fused_shared_experts=n_shared, - shared_experts_score=1.0, ) - assert y_vals.shape == (n_tokens, n_expts_act + n_shared) - assert y_indx.shape == (n_tokens, n_expts_act + n_shared) + assert y_vals.shape == (n_tokens, n_expts_act) + assert y_indx.shape == (n_tokens, n_expts_act) assert y_indx.dtype == torch.int16 assert y_vals.dtype == logits.dtype @@ -848,23 +829,15 @@ def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode, n_shared) ref_ids, ref_w, y_indx[:, :n_expts_act], y_vals[:, :n_expts_act] ) - # Shared slots: fixed id n_expts_tot+i, weight 1.0, for every token. 
- for i in range(n_shared): - ids_i = y_indx[:, n_expts_act + i].cpu().long() - w_i = y_vals[:, n_expts_act + i].float().cpu() - assert torch.all(ids_i == n_expts_tot + i), f"shared id col {i}: {ids_i}" - assert torch.allclose(w_i, torch.ones(n_tokens)), f"shared weight col {i}" - - _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot + n_shared) + _assert_bitmatrix_matches(bitmatrix, y_indx, n_tokens, n_expts_tot) # -------------------------------------------------------------------------- # 3. end-to-end routing_score_mode(use_grouped_topk=True) # -# Unified with the former test_routing_score_mode_grouped_shared: n_shared in {0,1,2}. -# grouped_topk (with the same n_shared) is the deterministic ground truth, and -# _check_routing_data_bucket validates hist / ExptData / inverse-permutation / -# per-expert (token, weight) multisets over the (widened) gate count. +# grouped_topk is the deterministic ground truth, and _check_routing_data_bucket +# validates hist / ExptData / inverse-permutation / per-expert (token, weight) +# multisets over the gate count. # -------------------------------------------------------------------------- @@ -872,15 +845,12 @@ def test_grouped_topk_kernel(n_tokens, shape, score_combo, group_mode, n_shared) @pytest.mark.parametrize( "n_expts_tot, num_expert_group, topk_group, n_expts_act", GROUP_SHAPES ) -@pytest.mark.parametrize("n_shared", [0, 1, 2]) def test_routing_score_mode_grouped( - n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act, n_shared + n_tokens, n_expts_tot, num_expert_group, topk_group, n_expts_act ): - """End-to-end routing(use_grouped_topk=True) with 0/1/2 fused always-on - shared experts. The routed selection must match the grouped_topk kernel, the - histogram must include a full shared bucket (n_tokens) per shared expert, and - gather/scatter must form a valid inverse permutation over the widened gate - count.""" + """End-to-end routing(use_grouped_topk=True). 
The routed selection must + match the grouped_topk kernel and gather/scatter must form a valid inverse + permutation over the gate count.""" _maybe_skip() device = "cuda" torch.manual_seed(2) @@ -888,12 +858,11 @@ def test_routing_score_mode_grouped( bias = torch.randn(n_expts_tot, dtype=torch.float32, device=device) * 0.05 score_mode, renorm, scale = "sqrtsoftplus", True, 2.5 - # routing derives block_m from the raw shape + pre-widening n_expts_act. + # routing derives block_m from the raw shape + n_expts_act. block_m = _routing_block_m(n_tokens, n_expts_act, n_expts_tot) - # The selection the kernel makes (deterministic for fixed inputs), including - # the appended shared experts; used as ground truth for the sort/scatter - # pipeline check. + # The selection the kernel makes (deterministic for fixed inputs); used as + # ground truth for the sort/scatter pipeline check. y_vals, y_indx, _ = grouped_topk( logits, n_expts_act, @@ -903,8 +872,6 @@ def test_routing_score_mode_grouped( bias=bias, renorm=renorm, routed_scaling_factor=scale, - num_fused_shared_experts=n_shared, - shared_experts_score=1.0, ) tri_routing_data, tri_gather, tri_scatter = routing( @@ -917,24 +884,18 @@ def test_routing_score_mode_grouped( use_grouped_topk=True, num_expert_group=num_expert_group, topk_group=topk_group, - num_fused_shared_experts=n_shared, ) - n_total = n_expts_tot + n_shared ref_pack = _sort_and_build_torch( - y_vals.float(), y_indx.to(torch.int32), n_total, block_m + y_vals.float(), y_indx.to(torch.int32), n_expts_tot, block_m ) _check_routing_data_bucket( ref_pack, tri_routing_data, tri_gather, tri_scatter, y_vals.float(), y_indx ) - assert tri_routing_data.n_expts_tot == n_total - assert tri_routing_data.n_expts_act == n_expts_act + n_shared + assert tri_routing_data.n_expts_tot == n_expts_tot + assert tri_routing_data.n_expts_act == n_expts_act assert tri_routing_data.block_m == block_m - # Each fused shared expert is an always-on bucket: every token routed once. 
- for i in range(n_shared): - assert tri_routing_data.expt_hist[n_expts_tot + i].item() == n_tokens - def bench_routing(): import triton.profiler as proton