diff --git a/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv index 119932fa65..c6521c6762 100644 --- a/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv @@ -1,33 +1,50 @@ M,N,K,q_dtype_w -64,1536,5120,torch.int8 -128,1536,5120,torch.int8 -256,1536,5120,torch.int8 -512,1536,5120,torch.int8 -1024,1536,5120,torch.int8 -1664,1536,5120,torch.int8 -4096,1536,5120,torch.int8 -8192,1536,5120,torch.int8 -10240,1536,5120,torch.int8 -12288,1536,5120,torch.int8 -16384,1536,5120,torch.int8 -20480,1536,5120,torch.int8 -24576,1536,5120,torch.int8 -30720,1536,5120,torch.int8 -32768,1536,5120,torch.int8 -40960,1536,5120,torch.int8 -64,5120,1280,torch.int8 -128,5120,1280,torch.int8 -256,5120,1280,torch.int8 -512,5120,1280,torch.int8 -1024,5120,1280,torch.int8 -1664,5120,1280,torch.int8 -4096,5120,1280,torch.int8 -8192,5120,1280,torch.int8 -10240,5120,1280,torch.int8 -12288,5120,1280,torch.int8 -16384,5120,1280,torch.int8 -20480,5120,1280,torch.int8 -24576,5120,1280,torch.int8 -30720,5120,1280,torch.int8 -32768,5120,1280,torch.int8 -40960,5120,1280,torch.int8 +16,2112,7168,torch.float8_e4m3fn +32,2112,7168,torch.float8_e4m3fn +64,2112,7168,torch.float8_e4m3fn +128,2112,7168,torch.float8_e4m3fn +256,2112,7168,torch.float8_e4m3fn +512,2112,7168,torch.float8_e4m3fn +1024,2112,7168,torch.float8_e4m3fn +2048,2112,7168,torch.float8_e4m3fn +4096,2112,7168,torch.float8_e4m3fn +8192,2112,7168,torch.float8_e4m3fn +16384,2112,7168,torch.float8_e4m3fn +32768,2112,7168,torch.float8_e4m3fn +16,24576,1536,torch.float8_e4m3fn +32,24576,1536,torch.float8_e4m3fn +64,24576,1536,torch.float8_e4m3fn +128,24576,1536,torch.float8_e4m3fn +256,24576,1536,torch.float8_e4m3fn +512,24576,1536,torch.float8_e4m3fn +1024,24576,1536,torch.float8_e4m3fn +2048,24576,1536,torch.float8_e4m3fn +4096,24576,1536,torch.float8_e4m3fn +8192,24576,1536,torch.float8_e4m3fn +16384,24576,1536,torch.float8_e4m3fn 
+32768,24576,1536,torch.float8_e4m3fn +16,32768,512,torch.float8_e4m3fn +32,32768,512,torch.float8_e4m3fn +64,32768,512,torch.float8_e4m3fn +128,32768,512,torch.float8_e4m3fn +256,32768,512,torch.float8_e4m3fn +512,32768,512,torch.float8_e4m3fn +1024,32768,512,torch.float8_e4m3fn +2048,32768,512,torch.float8_e4m3fn +4096,32768,512,torch.float8_e4m3fn +8192,32768,512,torch.float8_e4m3fn +16384,32768,512,torch.float8_e4m3fn +32768,32768,512,torch.float8_e4m3fn +16,7168,16384,torch.float8_e4m3fn +32,7168,16384,torch.float8_e4m3fn +64,7168,16384,torch.float8_e4m3fn +128,7168,16384,torch.float8_e4m3fn +256,7168,16384,torch.float8_e4m3fn +512,7168,16384,torch.float8_e4m3fn +1024,7168,16384,torch.float8_e4m3fn +2048,7168,16384,torch.float8_e4m3fn +4096,7168,16384,torch.float8_e4m3fn +8192,7168,16384,torch.float8_e4m3fn +16384,7168,16384,torch.float8_e4m3fn +32768,7168,16384,torch.float8_e4m3fn + diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index 7562f31c31..d6cc0ae675 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -3,12 +3,15 @@ """gfx1250 (WMMA) backend for the FlyDSL a8w8 bpreshuffle GEMM. -aiter.gemm_a8w8_bpreshuffle routes here when its tuned kernelName starts with -``flydsl_bpreshuffle_wmma_`` (gfx1250 has no MFMA preshuffle kernel). Runs the -vendored gemm_fp8fp4_gfx1250 WMMA kernel in ptpc scale mode: C = (A*sa) @ (B*sb)^T -with fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. N/K are not -padded (must divide the tile); M is padded to tile_m when ragged, since the kernel -reads a full tile_m rows per workgroup (M=1 would otherwise read past A/sa -> NaN). +aiter.gemm_a8w8_bpreshuffle's FlyDSL path runs here on gfx1250 (no MFMA preshuffle +kernel); the tuned kernelName (prefix ``flydsl_bpreshuffle_wmma_``) encodes the tile +config. 
Computes C = (A*sa) @ (B*sb)^T via the vendored gemm_fp8fp4_gfx1250 WMMA +kernel, fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. + +N/K must divide the tile; M may be ragged (no host padding) — the kernel clips +loads/stores to the runtime M via hardware OOB, predicating split-k's atomic add +per-lane on row < M. A/C may be strided (lda/ldc passed at runtime, no copy) when +the inner dim is unit-stride; B is preshuffled into its own contiguous buffer. """ from __future__ import annotations @@ -26,6 +29,7 @@ _WMMA_K = 128 _SUPPORTED_NUM_BUFFERS = (2, 3, 4) _OUT_DTYPE_NAME = {torch.bfloat16: "bf16", torch.float16: "f16"} +_MAX_SPLIT_K = 4 def _lazy_import(): @@ -54,10 +58,6 @@ def _as_1d_fp32(scale: Tensor, length: int, name: str) -> Tensor: return flat.contiguous() -def _to_uint8(t: Tensor) -> Tensor: - return t.contiguous().view(torch.uint8).view(-1) - - def run_preshuffle_gemm_a8_gfx1250( XQ: Tensor, WQ: Tensor, @@ -113,33 +113,36 @@ def run_preshuffle_gemm_a8_gfx1250( cluster_m = max(1, int(cluster_m)) cluster_n = max(1, int(cluster_n)) - accumulate_fp32 = split_k > 1 - kernel_out_dtype = "f32" if accumulate_fp32 else out_dtype + if split_k > _MAX_SPLIT_K: + raise RuntimeError( + f"[FlyDSL gfx1250] split_k={split_k} exceeds the bf16/f16 atomic-add " + f"precision cap of {_MAX_SPLIT_K}" + ) - # Pipeline depth needs >= 1 K tile per buffer (per split-k chunk). - num_k_tiles = (K // split_k) // tile_k - nb = max(2, min(int(num_buffers), num_k_tiles)) + # Validate (tuned names always pass); fail loudly rather than silently clamp. 
+ nb = int(num_buffers) if nb not in _SUPPORTED_NUM_BUFFERS: - nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb) + raise RuntimeError( + f"[FlyDSL gfx1250] num_buffers must be one of {_SUPPORTED_NUM_BUFFERS}, " + f"got {nb}" + ) + if K % (split_k * tile_k) != 0: + raise RuntimeError( + f"[FlyDSL gfx1250] K={K} must be divisible by split_k*tile_k=" + f"{split_k}*{tile_k}={split_k * tile_k}" + ) + # Each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. + num_k_tiles = (K // split_k) // tile_k + if num_k_tiles < nb: + raise RuntimeError( + f"[FlyDSL gfx1250] {nb}-buffer pipeline needs >= {nb} K-tiles per " + f"split-k chunk, got {num_k_tiles} (K={K}, split_k={split_k}, tile_k={tile_k})" + ) sa = _as_1d_fp32(x_scale, M, "x_scale") sb = _as_1d_fp32(w_scale, N, "w_scale") - # M padded to tile_m when ragged (kernel reads a full tile_m rows/wg). - padded_m = ((M + tile_m - 1) // tile_m) * tile_m - if padded_m == M: - a_dev = XQ.contiguous() - sa_dev = sa - else: - a_dev = torch.zeros((padded_m, K), dtype=XQ.dtype, device=XQ.device) - a_dev[:M] = XQ - sa_dev = torch.ones((padded_m,), dtype=torch.float32, device=sa.device) - sa_dev[:M] = sa - - b_dev = WQ.contiguous() - exe = _compile_ptpc_gemm( - M=padded_m, N=N, K=K, data_format="fp8", @@ -152,40 +155,37 @@ def run_preshuffle_gemm_a8_gfx1250( waves_per_eu=(None if waves_per_eu <= 0 else waves_per_eu), cluster_m=cluster_m, cluster_n=cluster_n, - out_dtype=kernel_out_dtype, + out_dtype=out_dtype, split_k=split_k, ) - if accumulate_fp32: - # fp32 atomic-accumulation scratch (zeroed; narrowed into Out below). 
- out_buf = torch.zeros((padded_m, N), dtype=torch.float32, device=Out.device) - elif padded_m == M: - out_buf = Out.contiguous() - else: - out_buf = torch.empty((padded_m, N), dtype=Out.dtype, device=Out.device) + lda = XQ.stride(0) + ldc = Out.stride(0) + if split_k > 1: + Out.zero_() # split-k atomic-accumulates into Out - stream = _fx.Stream(torch.cuda.current_stream(device=a_dev.device)) + stream = _fx.Stream(torch.cuda.current_stream(device=XQ.device)) _run_compiled( exe, - out_buf.view(-1), - _to_uint8(a_dev), - _to_uint8(b_dev), - sa_dev.contiguous().view(-1), - sb.contiguous().view(-1), - padded_m, + Out, + XQ.view(torch.uint8), + WQ.view(torch.uint8), + sa.view(-1), + sb.view(-1), + M, N, + lda, + ldc, stream, ) - - if out_buf.data_ptr() != Out.data_ptr(): - Out.copy_(out_buf[:M]) return Out -# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_nb{nb}_sk{sk}_cm{cm}_cn{cn} +# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_mw{mw}_nw{nw}_nb{nb}_sk{sk}_cm{cm}_cn{cn} _KERNEL_NAME_RE = re.compile( r"^flydsl_bpreshuffle_wmma_" r"t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)_" + r"mw(?P<m_warp>\d+)_nw(?P<n_warp>\d+)_" r"nb(?P<num_buffers>\d+)_sk(?P<split_k>\d+)_" r"cm(?P<cluster_m>\d+)_cn(?P<cluster_n>\d+)$" ) @@ -200,10 +200,13 @@ def wmma_kernel_name( split_k: int, cluster_m: int, cluster_n: int, + m_warp: int, + n_warp: int, ) -> str: return ( f"flydsl_bpreshuffle_wmma_t{tile_m}x{tile_n}x{tile_k}_" - f"nb{num_buffers}_sk{split_k}_cm{cluster_m}_cn{cluster_n}" + f"mw{m_warp}_nw{n_warp}_nb{num_buffers}_sk{split_k}_" + f"cm{cluster_m}_cn{cluster_n}" ) @@ -238,4 +241,6 @@ def run_gemm_a8w8_bpreshuffle_gfx1250( split_k=cfg["split_k"], cluster_m=cfg["cluster_m"], cluster_n=cfg["cluster_n"], + m_warp=cfg["m_warp"], + n_warp=cfg["n_warp"], ) diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index 13a4b12865..940453f2c7 100644 --- 
a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++
b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -6,26 +6,49 @@ gfx1250 has no MFMA preshuffle kernel; it runs the vendored gemm_fp8fp4_gfx1250 WMMA kernel (ptpc) via ``bpreshuffle_gemm_gfx1250``. This is the WMMA counterpart of ``flydsl_gemm_a8w8_bpreshuffle_common`` (which serves gfx942/gfx950 MFMA), with -its own perf knobs — num_buffers, split_k, cluster — and kernelName format. +its own perf knobs (num_buffers and split_k) and kernelName format. Cluster is +kept in the candidate schema/name but fixed to (1, 1). """ from __future__ import annotations from dataclasses import dataclass -from itertools import product from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name - -WMMA = 16 -WARP = 2 # m_warp = n_warp = 2 -> tile_m/n must be multiples of 32 -LDS_BYTES = 320 * 1024 - -_TILE_M = (32, 64, 256) -_TILE_N = (32, 64, 256) -_TILE_K = (128, 256) -_NUM_BUFFERS = (2, 3, 4) -_SPLIT_K = (1, 2, 4, 8, 16) -_CLUSTER = ((1, 1), (2, 2), (2, 4), (4, 2)) # cluster_m * cluster_n <= 16 +from aiter.ops.flydsl.utils import get_shared_memory_per_block + +WMMA = 16 # WMMA M/N tile granularity +WARP = 2 # default m_warp / n_warp for WmmaKernelInstance +LDS_BYTES = get_shared_memory_per_block(fallback_gfx="gfx1250") +_MAX_WARP_TILE = 128 + +# Mirror the ptpc fp8 LDS layout +_LDS_PAD_A_BYTES = 16 +_LDS_PAD_D_BYTES = 16 +_ELEM_BYTES_D = 2 # bf16 / f16 output + +# Columns: (tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers), grouped by M regime +# fmt: off +_CURATED_INSTANCES = ( + # small M (decode / token-gen): thin tile_m, wide tile_n, deep tile_k + ( 16, 64, 256, 1, 2, 4), ( 16, 64, 512, 1, 2, 4), ( 16, 96, 256, 1, 2, 4), + ( 16, 128, 256, 1, 2, 4), ( 16, 128, 512, 1, 2, 4), ( 16, 192, 256, 1, 2, 4), + ( 16, 256, 256, 1, 4, 4), ( 16, 512, 128, 1, 4, 4), + # M=32 + ( 32, 64, 256, 2, 2, 4), ( 32, 64, 512, 2, 2, 4), ( 32, 128, 256, 2, 2, 4), + ( 32, 192, 256, 2, 2, 4), ( 32, 256, 256, 2, 4, 4), + # M=64 + ( 64, 64, 256, 2, 2, 4), ( 64, 
128, 128, 2, 2, 4), ( 64, 192, 128, 2, 2, 4), + ( 64, 256, 128, 2, 4, 4), ( 64, 512, 128, 1, 4, 3), + # M=128 + (128, 128, 128, 2, 2, 4), (128, 192, 128, 2, 2, 4), (128, 256, 128, 2, 4, 4), + (128, 512, 128, 2, 4, 3), + # large M (compute bound): big square tiles, shallow tile_k + (256, 64, 128, 2, 1, 4), (256, 128, 128, 2, 2, 4), (256, 192, 128, 2, 2, 4), + (256, 256, 128, 2, 2, 4), (256, 256, 128, 4, 4, 3), (256, 512, 128, 2, 4, 3), +) +# fmt: on +_SPLIT_K = (1,) @dataclass @@ -37,6 +60,8 @@ class WmmaKernelInstance: split_k: int = 1 cluster_m: int = 1 cluster_n: int = 1 + m_warp: int = WARP + n_warp: int = WARP @property def name(self) -> str: @@ -48,61 +73,80 @@ def name(self) -> str: split_k=self.split_k, cluster_m=self.cluster_m, cluster_n=self.cluster_n, + m_warp=self.m_warp, + n_warp=self.n_warp, ) -def _tile_valid(tm: int, tn: int, tk: int) -> bool: - return tm % (WARP * WMMA) == 0 and tn % (WARP * WMMA) == 0 and tk % 128 == 0 +def _tile_valid(tm: int, tn: int, tk: int, mw: int, nw: int) -> bool: + # Each warp tile must be a multiple of WMMA (16) and <= _MAX_WARP_TILE (VGPR + # budget); tk a multiple of 128; block_threads = mw*nw*32 <= 1024. + return ( + tm % (mw * WMMA) == 0 + and tn % (nw * WMMA) == 0 + and tm // mw <= _MAX_WARP_TILE + and tn // nw <= _MAX_WARP_TILE + and tk % 128 == 0 + and mw * nw <= 32 + ) + + +def _align_up(value: int, align: int) -> int: + return (value + align - 1) // align * align def kernel_instance_estimated_lds_bytes(ki: WmmaKernelInstance) -> int: - return (ki.tile_m * ki.tile_k + ki.tile_n * ki.tile_k) * ki.num_buffers + """LDS bytes the ptpc fp8 WMMA kernel allocates for ``ki`` (must not under-estimate: + an overflowing tile would pass the filter and fault at launch). + Per-stage arena: A pool (rows padded by LDS_PAD_A_BYTES) + 16-aligned B pool, the + stage 128- then 1024-aligned, times num_buffers. The split_k==1 epilogue also needs + a TDM-store D buffer that can exceed the arena for small tiles, so take the max. 
+ """ + lds_a_data = ki.tile_m * (ki.tile_k + _LDS_PAD_A_BYTES) + lds_b_data = ki.tile_n * ki.tile_k + stage_bytes = _align_up(lds_a_data, 16) + lds_b_data + stage_pitch = _align_up(_align_up(stage_bytes, 128), 1024) + arena_bytes = stage_pitch * ki.num_buffers -def max_lds_bytes_for_tune() -> int: - return LDS_BYTES + if ki.split_k == 1: # split_k>1 uses the buffer/atomic store, no LDS D buffer + warp_tile_m = ki.tile_m // ki.m_warp + warp_tile_n = ki.tile_n // ki.n_warp + d_row_stride = warp_tile_n * _ELEM_BYTES_D + _LDS_PAD_D_BYTES + total_d_bytes = (ki.m_warp * ki.n_warp) * warp_tile_m * d_row_stride + return max(arena_bytes, total_d_bytes) + return arena_bytes def _build_kernels_list(): kl = {} idx = 0 - for nb, sk, (cm, cn), tm, tn, tk in product( - _NUM_BUFFERS, _SPLIT_K, _CLUSTER, _TILE_M, _TILE_N, _TILE_K - ): - if not _tile_valid(tm, tn, tk): - continue - ki = WmmaKernelInstance(tm, tn, tk, nb, sk, cm, cn) - if kernel_instance_estimated_lds_bytes(ki) > LDS_BYTES: - continue - kl[idx] = ki - idx += 1 + for tm, tn, tk, mw, nw, nb in _CURATED_INSTANCES: + assert _tile_valid( # an invalid curated entry is a typo -- fail loudly + tm, tn, tk, mw, nw + ), f"invalid curated instance: tile=({tm},{tn},{tk}) warp=({mw},{nw})" + for sk in _SPLIT_K: + ki = WmmaKernelInstance(tm, tn, tk, nb, sk, 1, 1, mw, nw) + if kernel_instance_estimated_lds_bytes(ki) > LDS_BYTES: + continue + kl[idx] = ki + idx += 1 return kl kernels_list: dict[int, WmmaKernelInstance] = _build_kernels_list() -default_kernels_dict = { - (-1): WmmaKernelInstance(128, 128, 128, 2), - (-2): WmmaKernelInstance(32, 64, 128, 2), - (-3): WmmaKernelInstance(64, 128, 128, 2), - (-4): WmmaKernelInstance(128, 256, 128, 2), -} - def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: - """N must divide tile_n (N is never padded); K must divide split_k*tile_k, and - each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. 
M is - padded to tile_m, so ragged M is fine without a cluster; a cluster needs an - evenly divisible grid and only pays off for M, N >= 4096. - (LDS is bounded at build time, so it is not re-checked here.) + """N must divide tile_n; K must divide split_k*tile_k with >= num_buffers K-tiles + per chunk. M may be ragged (kernel OOB-clips, no divisibility needed). A cluster + also needs N divisible by cluster_n*tile_n. """ if N % ki.tile_n != 0 or K % (ki.split_k * ki.tile_k) != 0: return False if (K // ki.split_k) // ki.tile_k < ki.num_buffers: return False if ki.cluster_m > 1 or ki.cluster_n > 1: - if M < 4096 or N < 4096: - return False - if M % (ki.cluster_m * ki.tile_m) != 0 or N % (ki.cluster_n * ki.tile_n) != 0: + if N % (ki.cluster_n * ki.tile_n) != 0: return False return True diff --git a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py b/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py deleted file mode 100644 index bb415faf46..0000000000 --- a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py +++ /dev/null @@ -1,285 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Auto-tuner for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. - -FlyDSL-only counterpart of the MFMA tuner in -``csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py``: for each untuned -(M, N, K) it benchmarks the WMMA candidates from -``flydsl_gemm_a8w8_bpreshuffle_wmma_common`` and writes the winner's kernelName -(``flydsl_bpreshuffle_wmma_*``, libtype "flydsl") into the a8w8 bpreshuffle tuned -CSV. The public ``aiter.gemm_a8w8_bpreshuffle`` op then routes to it on gfx1250. 
- -Usage:: - - python aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py \ - --untune_file aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv \ - --tune_file aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv -""" - -import pandas as pd -import torch - -import aiter -from aiter import dtypes -from aiter.jit.core import AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, AITER_ROOT_DIR -from aiter.utility.base_tuner import GemmCommonTuner -from aiter.utility.mp_tuner import mp_tuner -from aiter.ops.shuffle import shuffle_weight -from aiter.ops.flydsl.utils import is_flydsl_available -from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( - kernels_list, - kernel_fits_shape, -) - -if is_flydsl_available(): - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import run_preshuffle_gemm_a8_gfx1250 - -_OUT_TORCH = {"bf16": torch.bfloat16, "f16": torch.float16} -_Q_DTYPE_W = str(dtypes.fp8) - - -def run_torch(x, weight, x_scale, w_scale, dtype=torch.bfloat16): - a = x.to(torch.float32) * x_scale.to(torch.float32) - b = weight.to(torch.float32) * w_scale.to(torch.float32) - return (a @ b.t()).to(dtype) - - -def generate_data(m, n, k, seed, out_dtype="bf16", device="cuda"): - torch.manual_seed(seed) - x = torch.randn((m, k), dtype=torch.bfloat16, device=device) - weight = torch.randn((n, k), dtype=torch.bfloat16, device=device) - x, x_scale = aiter.pertoken_quant(x, quant_dtype=dtypes.fp8) - weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=dtypes.fp8) - weight_shuffle = shuffle_weight(weight, layout=(16, 16)) - out = torch.empty(m, n, dtype=_OUT_TORCH[out_dtype], device=device) - return { - "x": x, - "weight": weight, - "weight_shuffle": weight_shuffle, - "x_scale": x_scale, - "w_scale": w_scale, - "out": out, - } - - -def run_gemm(x, weight_shuffle, x_scale, w_scale, out, kernel_id): - ki = kernels_list[kernel_id] - run_preshuffle_gemm_a8_gfx1250( - x, - weight_shuffle, - x_scale, - w_scale, - out, - ki.tile_m, - ki.tile_n, - ki.tile_k, - 
num_buffers=ki.num_buffers, - split_k=ki.split_k, - cluster_m=ki.cluster_m, - cluster_n=ki.cluster_n, - ) - return out - - -class GemmA8W8BpreShuffleWmmaTuner(GemmCommonTuner): - ARG_DEFAULTS = { - **GemmCommonTuner.ARG_DEFAULTS, - "tune_file": f"{AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE}", - "untune_file": f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv", - "config_env_name": "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", - } - - def _clear_op_caches(self): - from aiter.ops import gemm_op_a8w8 as _op - - _op.get_GEMM_config_with_quant_type.cache_clear() - _op._GEMM_QUANT_TYPE_CACHE.clear() - _op._GEMM_QUANT_TYPE_HAS_GFX.clear() - - def _setup_specific_arguments(self): - self.parser.add_argument( - "--out_dtype", - type=str, - default="bf16", - choices=["bf16", "f16"], - help="Output dtype to tune (run once per dtype your model needs)", - ) - - def calculate(self, results, bpes=(1, 1, 2)): - return super().calculate(results, bpes=bpes) - - def getKernelName(self, kernelId, libtype="flydsl"): - ki = kernels_list.get(kernelId) - return ki.name if ki is not None else None - - def result_to_df(self, results): - resultdf = pd.DataFrame(columns=self.columns) - for el in results: - info, time, err_ratio = el - keys, kernelId, splitK, kernelName, libtype = info - if time == self.INVALID_TIME: - kernelName = "None" - elif kernelName == "": - resolved = self.getKernelName(kernelId, libtype) - kernelName = "None" if resolved is None else str(resolved) - tflops, bw = self.calculate(el) - key_dict = dict(zip(self.keys, keys)) - key_dict.update( - { - "libtype": [libtype], - "kernelId": [kernelId], - "splitK": [splitK], - "us": [time], - "kernelName": [kernelName], - "errRatio": [err_ratio], - "tflops": [tflops], - "bw": [bw], - } - ) - temp = pd.DataFrame(key_dict) - resultdf = ( - temp - if resultdf.empty - else pd.concat([resultdf, temp], ignore_index=True) - ) - return resultdf - - def get_wmma_tune_task(self, info_keys, out_dtype, seed, args): - gfx, cu_num, M, N, K, 
q_dtype_w = info_keys - if not is_flydsl_available(): - return [] - run_keys = ["x", "weight_shuffle", "x_scale", "w_scale", "out"] - ref_keys = ["x", "weight", "x_scale", "w_scale"] - tasks = [] - for i in sorted(kernels_list.keys()): - ki = kernels_list[i] - if not kernel_fits_shape(ki, M, N, K): - continue - info = (info_keys, i, 0, ki.name, "flydsl") - tasks.append( - ( - info, - generate_data, - (M, N, K, seed, out_dtype), - run_gemm, - (run_keys, i), - {"num_warmup": args.warmup, "num_iters": args.iters}, - run_torch, - (ref_keys, _OUT_TORCH[out_dtype]), - {}, - None, - 1e-2, - 1e-2, - None, - None, - ("out",), - ) - ) - return tasks - - def tune(self, untunedf, tunedf, args): - mp_num = args.mp - shape_grouped = args.shape_grouped - errRatio = args.errRatio - cu_num = self.get_cu_num() - gfx = self.get_gfx() - out_dtype = args.out_dtype - task = [] - tasks_data = [] - seed = 0 - already = set() - if tunedf is not None and not tunedf.empty and "q_dtype_w" in tunedf.columns: - sub = tunedf[tunedf["q_dtype_w"] == _Q_DTYPE_W] - already = {(int(r.M), int(r.N), int(r.K)) for r in sub.itertuples()} - for i in range(len(untunedf)): - M = untunedf.loc[i, "M"] - N = untunedf.loc[i, "N"] - K = untunedf.loc[i, "K"] - if ( - "q_dtype_w" in untunedf.columns - and str(untunedf.loc[i, "q_dtype_w"]) != _Q_DTYPE_W - ): - continue - if (int(M), int(N), int(K)) in already: - continue - seed += 1 - prev = len(task) - info_keys = (gfx, cu_num, M, N, K, _Q_DTYPE_W) - task.extend(self.get_wmma_tune_task(info_keys, out_dtype, seed, args)) - tasks_data.append((len(task) - prev, ())) - ret = [] - if task: - ret = mp_tuner( - task, - tasks_data, - mp_num, - False, - shape_grouped, - errRatio, - timeout=args.timeout, - verbose=args.verbose, - ) - return ret - - def run_config(self, args): - from aiter.test_common import run_perftest, checkAllclose - - out_dtype = args.out_dtype - allowed = args.errRatio - results = [] - for i in range(len(self.untunedf)): - row = self.untunedf.iloc[i] - 
if "q_dtype_w" in row and str(row["q_dtype_w"]) != _Q_DTYPE_W: - continue - M, N, K = int(row["M"]), int(row["N"]), int(row["K"]) - shape_str = f"M{M}_N{N}_K{K}_fp8_{out_dtype}" - try: - d = generate_data(M, N, K, seed=0, out_dtype=out_dtype) - ref = run_torch( - d["x"], - d["weight"], - d["x_scale"], - d["w_scale"], - _OUT_TORCH[out_dtype], - ) - out, us = run_perftest( - aiter.gemm_a8w8_bpreshuffle, - d["x"], - d["weight_shuffle"], - d["x_scale"], - d["w_scale"], - dtype=_OUT_TORCH[out_dtype], - ) - err = checkAllclose(out, ref, msg=f"run_config {shape_str}") - status = "ok" if err <= allowed else f"mismatch:err_ratio={err:.6g}" - results.append({"shape": shape_str, "e2e_us": us, "status": status}) - except Exception as e: # noqa: BLE001 - results.append( - {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} - ) - finally: - torch.cuda.empty_cache() - return results - - -if __name__ == "__main__": - key = ["gfx", "cu_num", "M", "N", "K", "q_dtype_w"] - resultList = [ - "libtype", - "kernelId", - "splitK", - "us", - "kernelName", - "tflops", - "bw", - "errRatio", - ] - tuner = GemmA8W8BpreShuffleWmmaTuner( - "GemmA8W8BpreShuffleWmmaTuner", - key=key, - resultList=resultList, - description="Auto-tuner for gfx1250 (WMMA) a8w8 bpreshuffle GEMM", - ) - args = tuner.parse_args() - tuner.run(args, False) diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index baa59da3dd..c44424b81e 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -13,7 +13,7 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir -from flydsl._mlir.dialects import fly, llvm +from flydsl._mlir.dialects import fly, llvm, scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import ( arith, @@ -57,14 +57,29 @@ def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): _llvm.inline_asm(None, [], 
"\n".join(lines), "", has_side_effects=True) -# compatible with no early_timeout descriptor -_TDM_HAS_EARLY_TIMEOUT = ( - "early_timeout" in inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters -) +# Feature-detect the installed flydsl's TDM descriptor builder. Older pinned +# flydsl predates these args; we apply each only when supported and otherwise +# fall back to the vendored OOB-capable builder for the non-tile-aligned-M path. +_TDM_SIG_PARAMS = inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters +_TDM_HAS_EARLY_TIMEOUT = "early_timeout" in _TDM_SIG_PARAMS +_TDM_HAS_OOB = "oob_outer_bound" in _TDM_SIG_PARAMS + + +def _make_tdm_desc(*, early_timeout=False, oob_outer_bound=None, **kwargs): + """Build a 2D TDM descriptor, transparently across flydsl versions.""" + strides = kwargs.get("strides") + runtime_stride = strides is not None and not isinstance(strides[0], int) + needs_oob = oob_outer_bound is not None + if runtime_stride or (needs_oob and not _TDM_HAS_OOB): + from .tdm_oob import make_tensor_descriptor_2d as _vendored_make_desc -def _make_tdm_desc(*, early_timeout=False, **kwargs): - """Build a TDM descriptor, applying early_timeout only when supports it.""" + return _vendored_make_desc( + early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs + ) + + if _TDM_HAS_OOB: + kwargs["oob_outer_bound"] = oob_outer_bound if _TDM_HAS_EARLY_TIMEOUT: kwargs["early_timeout"] = early_timeout return tdm_ops.make_tensor_descriptor_2d(**kwargs) @@ -87,7 +102,6 @@ def compile_fp8fp4_gemm( *, data_format: str = "fp4", scale_mode: str = "mxscale", - M: int = 0, N: int = 0, K: int, tile_m: int = 128, @@ -126,7 +140,9 @@ def compile_fp8fp4_gemm( ptpc: scale_A [M], scale_B [N] fp32 Returns a JitFunction: - launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, stream) + launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, lda, ldc, stream) + where lda / ldc are the runtime leading-dim strides (in elements) of A / C. 
+ Pass lda == K and ldc == N for dense (contiguous) tensors. """ if data_format not in ("fp4", "fp8", "a8w4"): raise ValueError( @@ -608,6 +624,8 @@ def kernel_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, ): # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) rocdl.disable_xdl_arb_stall() @@ -689,7 +707,14 @@ def _bvs_prefetch(k_base): return a, b m_idx = fx.Index(i32_m) - n_stride = arith.index(N) + # Leading-dim strides arrive at runtime (strided A / C); the dense path + # passes lda == K and ldc == N, giving byte-identical addressing. A's + # stride is in packed-A elements (== lda for fp8 where PACK_FACTOR_A == 1). + if const_expr(PACK_FACTOR_A == 1): + lda_packed = fx.Index(i32_lda) + else: + lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) + n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) c_global_ptr_type = ir.Type.parse("!llvm.ptr<1>") @@ -707,7 +732,7 @@ def make_desc_a(memref, k_base): lds_memref=memref, global_offset=(blk_m, k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(tile_m, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -716,6 +741,7 @@ def make_desc_a(memref, k_base): workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b(memref, k_base): @@ -744,7 +770,7 @@ def make_desc_a_half(memref, k_base, m_half: int): lds_memref=memref, global_offset=(blk_m + arith.index(row_start), k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(ab_split_a_rows, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -754,6 +780,7 @@ def make_desc_a_half(memref, k_base, m_half: int): 
lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b_half(memref, k_base, n_half: int): @@ -1065,8 +1092,11 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): fmtB=0, ) else: - # for compatibility, use the following no-scale wmma instead future - # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) + # PTPC-FP8 needs no per-K scaling. We emit the scaled f8f6f4 op + # with an identity E8M0 scale (0x7F = 2^0 = 1.0) for toolchain + # compatibility; it is numerically equivalent to the dedicated + # no-scale op. Future: switch to the equivalent no-scale wmma: + # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( T.vec(8, T.f32), b_frag, @@ -2351,12 +2381,18 @@ def epilogue_atomic_adds(final_accs, addrs): addr_idx = 0 for acc_idx, vec_base, m_off, wn in _sub_tiles: sub8 = _get_acc_sub8(final_accs, acc_idx, vec_base) - if const_expr(_bf16_out): - addr_idx += _atomic_add_acc_vec8_to_buffer(sub8, addrs[addr_idx]) - else: - addr_idx += _atomic_add_acc_vec8_to_buffer( - sub8, addrs[addr_idx : addr_idx + 2] - ) + n_slots = 1 if _bf16_out else 2 + addr_arg = ( + addrs[addr_idx] if _bf16_out else addrs[addr_idx : addr_idx + 2] + ) + # Atomics use a raw global ptr (no num_records clip), so predicate + # per-lane to skip rows >= M. 
+ row = blk_m + warp_m_base + arith.index(m_off) + lane16 + if_op = scf.IfOp(row < m_idx, [], has_else=False) + with ir.InsertionPoint(if_op.then_block): + _atomic_add_acc_vec8_to_buffer(sub8, addr_arg) + scf.YieldOp([]) + addr_idx += n_slots def grouped_accs_to_row_major(accs_grouped): row_major = [None] * n_accs @@ -2372,8 +2408,16 @@ def finalize_acc_layout(accs_in): def epilogue_load_ptpc_scales(): # PTPC scales: sa[M] per-token (scalar per wm), sb[N] per-channel # (8 contiguous N cols per wn). Both fp32, constant along K. - sa_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - sb_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) + # The scale memrefs are dynamically shaped, so max_size=False would fall + # back to a max-sized descriptor and disable hardware OOB. Derive + # num_records from runtime M / compile-time N (fp32 = 4 bytes) so the + # partial last M-tile clips rows >= M (and cols >= N) to 0. + sa_rsrc = buffer_ops.create_buffer_resource( + arg_a_scale, num_records_bytes=m_idx * arith.index(4) + ) + sb_rsrc = buffer_ops.create_buffer_resource( + arg_b_scale, num_records_bytes=N * 4 + ) sa = [] for wm in range_constexpr(wmma_m_rep): row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 @@ -2557,7 +2601,7 @@ def _l2_prefetch(k_base): lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), tensor_shape=(warp_tile_m, warp_tile_n), - strides=(N, 1), + strides=(n_stride, 1), tile_shape=(warp_tile_m, warp_tile_n), elem_bytes=elem_bytes_d, pad_interval=warp_tile_n, @@ -2565,6 +2609,7 @@ def _l2_prefetch(k_base): num_warps=1, lds_byte_offset=d_warp_off_sgpr, for_store=True, + oob_outer_bound=i32_m, ) # TDM descriptor lane layout: dgroup0 = [predicate, lds_addr, addr_lo, addr_hi]. 
@@ -3121,7 +3166,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _ptpc_sa, _ptpc_sb = _ptpc_scale_box[0] accs = epilogue_apply_ptpc_scale(accs, _ptpc_sa, _ptpc_sb) - if const_expr(use_tdm_store): + def _emit_tdm_store(): if const_expr(d_need_epilogue_fence): _pipeline_fence(outstanding=0) rocdl.sched_barrier(0) @@ -3129,7 +3174,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): rocdl.s_wait_dscnt(0) tdm_ops.tensor_store_2d(d_desc) tdm_ops.tensor_wait(0) - else: + + def _emit_buffer_store(): rocdl.sched_barrier(0) if const_expr(epi_addrs_box[0] is None): epi_addrs_box[0] = epilogue_prepare_addrs() @@ -3138,6 +3184,21 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): else: epilogue_stores(accs, epi_addrs_box[0]) + if const_expr(use_tdm_store): + # Full M-tiles take the fast TDM store; the partial last M-tile + # (rows >= M) falls back to the buffer store, whose num_records clip + # drops the OOB rows. + full_tile = (blk_m + arith.index(tile_m)) <= m_idx + if_op = scf.IfOp(full_tile, [], has_else=True) + with ir.InsertionPoint(if_op.then_block): + _emit_tdm_store() + scf.YieldOp([]) + with ir.InsertionPoint(if_op.else_block): + _emit_buffer_store() + scf.YieldOp([]) + else: + _emit_buffer_store() + cache_tag = ( data_format, scale_mode, @@ -3175,6 +3236,8 @@ def launch_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, stream: fx.Stream, ): _ = cache_tag @@ -3187,6 +3250,11 @@ def launch_mxscale_gemm( gy = (i32_n + (tile_n - 1)) // tile_n gz = split_k + if const_expr(use_cluster): + # Cluster launch needs a cluster-divisible grid; the extra M-tiles + # are fully OOB (rows >= M) and the kernel clips them. 
+ gx = ((gx + (cluster_m - 1)) // cluster_m) * cluster_m + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None kernel_mxscale_gemm( arg_c, @@ -3196,6 +3264,8 @@ def launch_mxscale_gemm( arg_b_scale, i32_m, i32_n, + i32_lda, + i32_ldc, value_attrs={ "rocdl.waves_per_eu": effective_waves_per_eu, "rocdl.cluster_dims": ( @@ -3236,7 +3306,6 @@ def compile_a8w4_gemm(**kw): def compile_ptpc_gemm( *, - M: int = 0, N: int = 0, K: int, data_format: str = "fp8", @@ -3275,7 +3344,6 @@ def compile_ptpc_gemm( fp8_schedule="auto", scale_load_path="tdm", use_tdm_store=(split_k == 1), - M=M, N=N, K=K, tile_m=tile_m, diff --git a/aiter/ops/flydsl/kernels/tdm_oob.py b/aiter/ops/flydsl/kernels/tdm_oob.py new file mode 100644 index 0000000000..4e1e3ad867 --- /dev/null +++ b/aiter/ops/flydsl/kernels/tdm_oob.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Vendored OOB-capable TDM 2D descriptor builder for gfx1250. + +This is a faithful copy of ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` +as of the FlyDSL "add M out-of-bounds support" change, carried here so the +non-tile-aligned-M (OOB) GEMM path works against the *older* flydsl this aiter +build pins, whose ``make_tensor_descriptor_2d`` predates the ``oob_outer_bound`` +argument. + +The kernel only routes through this fallback when the installed flydsl lacks +native ``oob_outer_bound`` support (see ``_make_tdm_desc`` in +``gemm_fp8fp4_gfx1250``); when flydsl has it, the native builder is used. + +To stay robust across flydsl internal-layout changes, every low-level symbol is +sourced from the installed ``tdm_ops`` module namespace (i.e. whatever that +module successfully imported) rather than re-imported from private paths. The +only behavioural delta vs. 
the pinned builder is the ``oob_outer_bound`` branch +that computes a runtime ``tensor_dim1``; with ``oob_outer_bound=None`` the output +is byte-identical to the original path. +""" + +from __future__ import annotations + +import math +from typing import Tuple, Union + +from flydsl.expr.rocdl import tdm_ops as _tdm + +# Reuse whatever the installed tdm_ops bound — keeps us in lock-step with the +# pinned flydsl's lower-level primitives instead of guessing private paths. +ir = _tdm.ir +std_arith = _tdm.std_arith +llvm_dialect = _tdm.llvm_dialect +memref_dialect = _tdm.memref_dialect +arith = _tdm.arith +vector = _tdm.vector +_raw = _tdm._raw +T = _tdm.T +_ArithValue = _tdm._ArithValue +compute_warp_distribution = _tdm.compute_warp_distribution +compute_padding_encoding = _tdm.compute_padding_encoding +TDMDescriptor2D = _tdm.TDMDescriptor2D + + +def make_tensor_descriptor_2d( + global_ptr, + lds_memref, + global_offset: Tuple, + tensor_shape: Tuple[int, int], + strides: Tuple[int, int], + tile_shape: Tuple[int, int], + elem_bytes: int = 2, + pad_interval: int = 0, + pad_amount: int = 0, + num_warps: int = 1, + cache_policy: int = 0, + pred: int = 1, + workgroup_mask: Union[int, "ir.Value"] = 0, + lds_byte_offset=None, + for_store: bool = False, + atomic_barrier_enable: bool = False, + early_timeout: bool = False, + oob_outer_bound=None, +) -> "TDMDescriptor2D": + """Build a 2D TDM descriptor (vendored, OOB-capable). + + See ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` for the full + argument reference. ``oob_outer_bound`` is the runtime outer-dim global + extent (real M for a row-major A/C); when given, ``tensor_dim1`` is set to + the tile-start-relative remaining extent + ``max(0, oob_outer_bound - (outer_off + warp_off_outer))`` while + ``tile_dim1`` stays the full per-warp tile, so the partial last tile exceeds + the tensor bound and the hardware OOB-handles the overhang (fault-safe load, + zero-fill in LDS). 
Accepts a Python int or an i32/index ir.Value. ``None`` + keeps ``tensor_dim1 == tile_dim1`` (OOB off) — byte-identical to the + non-OOB path. + """ + from flydsl._mlir.dialects import fly as _fly_d + + outer_stride, inner_stride = strides + outer_tile, inner_tile = tile_shape + outer_off, inner_off = global_offset + + # The outer (leading-dim) stride may be a compile-time int or a runtime + # i32/index ir.Value (strided A/C, e.g. a row-slice whose row pitch exceeds + # the logical inner extent). Normalise to an index value for address math and + # an i32 value for the descriptor's stride field (sgpr5). + if isinstance(outer_stride, int): + outer_stride_idx = arith.index(outer_stride) + outer_stride_is_runtime = False + else: + os_val = ( + outer_stride.ir_value() + if hasattr(outer_stride, "ir_value") + else outer_stride + ) + if not isinstance(os_val, ir.Value): + raise TypeError( + f"outer stride must be int or i32/index ir.Value, " + f"got {type(outer_stride).__name__}" + ) + if isinstance(os_val.type, ir.IndexType): + # Wrap raw ir.Value so it supports the _ArithValue ops below (*, cast). + outer_stride_idx = _ArithValue(os_val) + elif isinstance(os_val.type, ir.IntegerType) and os_val.type.width == 32: + outer_stride_idx = arith.index_cast(T.index, os_val) + else: + raise TypeError( + f"outer stride ir.Value must be index or i32, got {os_val.type}" + ) + outer_stride_is_runtime = True + + # -- Warp distribution -- + warps_per_dim, block_per_warp = compute_warp_distribution( + [outer_tile, inner_tile], + num_warps, + ) + bpw_outer, bpw_inner = block_per_warp + warps_dim0 = warps_per_dim[0] + + if num_warps > 1: + # Auto-acquire SGPR wave_id via hardware register (TTMP8[29:25]). 
+ # This keeps the entire descriptor address chain in SALU, + from flydsl.expr import rocdl as _rocdl_ext + + _wid_i32 = _rocdl_ext.wave_id() + wave_id = arith.index_cast(T.index, _wid_i32) + warp_coord_outer = wave_id % arith.index(warps_dim0) + warp_coord_inner = wave_id / arith.index(warps_dim0) + warp_off_outer = warp_coord_outer * arith.index(bpw_outer) + warp_off_inner = warp_coord_inner * arith.index(bpw_inner) + else: + warp_off_outer = arith.index(0) + warp_off_inner = arith.index(0) + + # -- Global address (byte address for descriptor) -- + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__extract_to_ir_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + glb_elem_off = (outer_off + warp_off_outer) * outer_stride_idx + ( + inner_off + warp_off_inner + ) * arith.index(inner_stride) + glb_byte_off = glb_elem_off * arith.index(elem_bytes) + glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) + glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 + + # -- LDS address (byte address within shared memory) -- + lds_base_idx = _ArithValue( + memref_dialect.extract_aligned_pointer_as_index(lds_memref) + ) + # Compute padded LDS stride (elements) for the outer dim + if pad_interval > 0 and pad_amount > 0: + lds_inner_stride = inner_tile + pad_amount # padded row width + else: + lds_inner_stride = inner_tile + lds_warp_elem_off = warp_off_outer * arith.index(lds_inner_stride) + warp_off_inner + lds_warp_byte_off = lds_warp_elem_off * arith.index(elem_bytes) + lds_total_off = lds_base_idx + lds_warp_byte_off + if lds_byte_offset is not None: + lds_total_off = lds_total_off + lds_byte_offset + lds_addr_i32 = arith.index_cast(T.i32, lds_total_off) + + # ================================================================ + # GROUP0 (vector<4xi32>): pred, lds_addr, global_addr_lo/hi + # 
================================================================ + g0_s0 = arith.constant(pred, type=T.i32) + g0_s1 = lds_addr_i32 + i32 = ir.IntegerType.get_signless(32) + g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_addr_i64)).result) + hi_raw = _ArithValue(_raw(glb_addr_i64)).shrui(arith.constant(32, type=T.i64)) + g0_s3 = _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) | arith.constant( + 1 << 31, type=T.i32 + ) # type field = 2 in [31:30] + dgroup0 = vector.from_elements(T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3]) + + # ================================================================ + # GROUP1 (vector<8xi32>): config + tensor dims + strides + tile + # ================================================================ + # Descriptor dim ordering: dim0=innermost, dim1=outermost + tdim0 = bpw_inner # innermost extent per warp + tdim1 = bpw_outer # outermost extent per warp + tile_d0 = bpw_inner # block dim0 per warp + tile_d1 = bpw_outer # block dim1 per warp + + # Padding can be applied to the LDS address when copying from memory to LDS, + # but not when copying from LDS to memory + # (there is no "de-padding" operation; padding is ignored). 
+ if for_store and pad_interval > 0 and pad_amount > 0: + tile_d0 += pad_amount + pad_interval = 0 + pad_amount = 0 + + # stride_dim0 in descriptor = outermost stride in elements + stride0 = outer_stride + + # data_size = log2(elem_bytes) + data_size_code = int(math.log2(elem_bytes)) + + # Padding encoding + if pad_interval > 0 and pad_amount > 0: + elem_bits = elem_bytes * 8 + enc_interval, enc_amount = compute_padding_encoding( + pad_interval, pad_amount, elem_bits + ) + pad_enable = 1 + else: + enc_interval, enc_amount = 0, 0 + pad_enable = 0 + + # sgpr0: config bitfields + _abe = 1 if atomic_barrier_enable else 0 + _early_timeout = 1 if early_timeout else 0 + g1_s0_upper = ( + (data_size_code << 16) # data_size [17:16] + | (_abe << 18) # atomic_barrier_enable + | (0 << 19) # iterate_enable + | (pad_enable << 20) # pad_enable + | (_early_timeout << 21) # early_timeout + | (enc_interval << 22) # pad_interval [24:22] + | (enc_amount << 25) # pad_amount [31:25] + ) + + if isinstance(workgroup_mask, int): + g1_s0_val = (workgroup_mask & 0xFFFF) | g1_s0_upper + g1_s0 = arith.constant(g1_s0_val, type=T.i32) + else: + upper_const = arith.constant(g1_s0_upper, type=T.i32) + mask_i32 = arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32)) + g1_s0 = arith.ori(upper_const, mask_i32) + + # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] + g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) + + if oob_outer_bound is None: + # Compile-time tensor_dim1 == tile extent: OOB checking off. 
+ # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] + g1_s2 = arith.constant( + ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), + type=T.i32, + ) + # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] + g1_s3 = arith.constant( + ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), + type=T.i32, + ) + else: + # Runtime tensor_dim1 = max(0, oob_outer_bound - (outer_off + warp_off_outer)), + # tile-start-relative (the descriptor's global address already includes the + # tile/warp start). tile_dim1 (sgpr4) stays the full per-warp tile, so the + # partial last tile exceeds the tensor bound and the HW OOB-handles the + # overhang. tensor_dim0 (innermost) and the tile dims stay compile-time. + if isinstance(oob_outer_bound, int): + ob_i32 = arith.constant(oob_outer_bound, type=T.i32) + else: + ob_i32 = ( + oob_outer_bound.ir_value() + if hasattr(oob_outer_bound, "ir_value") + else oob_outer_bound + ) + if not isinstance(ob_i32, ir.Value): + raise TypeError( + f"oob_outer_bound must be int or i32/index ir.Value, " + f"got {type(oob_outer_bound).__name__}" + ) + if isinstance(ob_i32.type, ir.IndexType): + ob_i32 = arith.index_cast(T.i32, ob_i32) + elif not ( + isinstance(ob_i32.type, ir.IntegerType) and ob_i32.type.width == 32 + ): + raise TypeError( + f"oob_outer_bound ir.Value must be index or i32, got {ob_i32.type}" + ) + start_i32 = arith.index_cast(T.i32, outer_off + warp_off_outer) + tdim1_rt = arith.maxsi( + arith.subi(ob_i32, start_i32), arith.constant(0, type=T.i32) + ) + c16 = arith.constant(16, type=T.i32) + c_mask16 = arith.constant(0xFFFF, type=T.i32) + # sgpr2: tensor_dim0_hi[15:0] (const) | tensor_dim1_lo[31:16] (runtime) + g1_s2 = arith.ori( + arith.constant((tdim0 >> 16) & 0xFFFF, type=T.i32), + arith.shli(arith.andi(tdim1_rt, c_mask16), c16), + ) + # sgpr3: tensor_dim1_hi[15:0] (runtime) | tile_dim0[31:16] (const) + g1_s3 = arith.ori( + arith.andi(arith.shrui(tdim1_rt, c16), c_mask16), + arith.constant(tile_d0 << 16, type=T.i32), + ) + + # sgpr4: 
tile_dim1[15:0] | tile_dim2[31:16]=0 (always the full per-warp tile) + g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) + + # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim + if outer_stride_is_runtime: + # Runtime leading-dim stride: truncate the index to i32 (strides < 2^31). + g1_s5 = arith.index_cast(T.i32, outer_stride_idx) + else: + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + + # sgpr6-7: for 2D, no higher-dim strides + g1_s6 = arith.constant(0, type=T.i32) + g1_s7 = arith.constant(0, type=T.i32) + + dgroup1 = vector.from_elements( + T.vec(8, T.i32), + [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], + ) + + return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) diff --git a/aiter/ops/flydsl/utils.py b/aiter/ops/flydsl/utils.py index b3a1a49f55..ba6a60b00e 100644 --- a/aiter/ops/flydsl/utils.py +++ b/aiter/ops/flydsl/utils.py @@ -17,6 +17,8 @@ def addressable_lds_bytes_for_gfx(gfx: str) -> int: return _FALLBACK_MAX_LDS_BYTES if g.startswith("gfx950"): return 163840 + if g.startswith("gfx1250"): + return 327680 if g.startswith("gfx7") or g.startswith("gfx8"): return 32768 return 65536 diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 3f9627dcad..8192f715f1 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -128,7 +128,8 @@ def gemm_a8w8_bpreshuffle_flydsl( config: dict, ) -> Tensor: kernel_name = str(config.get("kernelName", "")) - if kernel_name.startswith("flydsl_bpreshuffle_wmma_"): + # gfx1250 runs the WMMA ptpc backend; other archs use the MFMA preshuffle path. + if get_gfx() == "gfx1250": from .flydsl.bpreshuffle_gemm_gfx1250 import run_gemm_a8w8_bpreshuffle_gfx1250 return run_gemm_a8w8_bpreshuffle_gfx1250( @@ -473,6 +474,8 @@ def get_GEMM_config_with_quant_type( msg = f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned, in {tuned_file}!" if "libtype" in config: msg += f" libtype is {config['libtype']}!" 
+ if "kernelName" in config: + msg += f" kernelName is {config['kernelName']} (kernelId {config.get('kernelId')})!" logger.info(msg) break if config is None: @@ -661,6 +664,26 @@ def gemm_a8w8_bpreshuffle( return gemm_a8w8_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, Y, splitK) elif libtype == "flydsl" and is_flydsl_available(): return gemm_a8w8_bpreshuffle_flydsl(XQ, WQ, x_scale, w_scale, Y, config) + + if get_gfx() == "gfx1250" and is_flydsl_available(): + from ..ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernel_fits_shape, + kernels_list, + ) + + fits = [ki for ki in kernels_list.values() if kernel_fits_shape(ki, m, n, k)] + if fits: + want_tm = min(256, max(16, 1 << (m - 1).bit_length())) + ki = min( + fits, key=lambda x: (abs(x.tile_m - want_tm), -x.tile_n, -x.tile_k) + ) + logger.warning( + f"[gfx1250] gemm_a8w8_bpreshuffle untuned M={m}, N={n}, K={k}; " + f"falling back to flydsl kernel '{ki.name}'." + ) + return gemm_a8w8_bpreshuffle_flydsl( + XQ, WQ, x_scale, w_scale, Y, {"kernelName": ki.name} + ) try: return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y, 0) except RuntimeError as e: diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py index bd9587b867..a87360b5f5 100755 --- a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py +++ b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py @@ -155,6 +155,34 @@ def run_gemm_flydsl(x, weight_shuffle, x_scale, w_scale, out, kernel_id): return out +def run_gemm_flydsl_gfx1250(x, weight_shuffle, x_scale, w_scale, out, kernel_id): + from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernels_list as kernels_list_flydsl_wmma, + ) + from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( + run_preshuffle_gemm_a8_gfx1250, + ) + + ki = kernels_list_flydsl_wmma[kernel_id] + run_preshuffle_gemm_a8_gfx1250( + x, + weight_shuffle, + x_scale, + w_scale, + out, 
+ ki.tile_m, + ki.tile_n, + ki.tile_k, + num_buffers=ki.num_buffers, + split_k=ki.split_k, + cluster_m=ki.cluster_m, + cluster_n=ki.cluster_n, + m_warp=ki.m_warp, + n_warp=ki.n_warp, + ) + return out + + def generate_data( m, n, k, seed, dtype=dtypes.bf16, q_dtype_w=dtypes.fp8, is_asm=False, device="cuda" ): @@ -464,6 +492,10 @@ def get_flydsl_gemm_a8w8_bpreshuffle_tune_task( seed, ): gfx, cu_num, M, N, K, q_dtype_w = info_keys + + if gfx == "gfx1250": + return self._get_flydsl_tune_task_gfx1250(info_keys, seed) + q_dtype_eval = eval(q_dtype_w) if q_dtype_eval == dtypes.fp8: pass @@ -541,6 +573,63 @@ def get_flydsl_gemm_a8w8_bpreshuffle_tune_task( ) return tasks + def _get_flydsl_tune_task_gfx1250(self, info_keys, seed): + """gfx1250 WMMA ptpc tuning tasks for the FlyDSL libtype.""" + gfx, cu_num, M, N, K, q_dtype_w = info_keys + if eval(q_dtype_w) != dtypes.fp8: + print( + f"[FlyDSL][gfx1250] WMMA ptpc supports fp8 only, skipping {q_dtype_w}" + ) + return [] + if not is_flydsl_available(): + return [] + try: + from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernels_list as kernels_list_flydsl_wmma, + kernel_fits_shape as kernel_fits_shape_wmma, + ) + except ImportError: + return [] + if not kernels_list_flydsl_wmma: + return [] + gemm_keys = ["x", "weight_shuffle", "x_scale", "w_scale", "out"] + ref_keys = ["x", "weight", "x_scale", "w_scale", "bias_f32"] + tasks = [] + for i in sorted(kernels_list_flydsl_wmma.keys()): + ki = kernels_list_flydsl_wmma[i] + if not kernel_fits_shape_wmma(ki, M, N, K): + continue + info = (info_keys, i, 0, ki.name, "flydsl") + tasks.append( + ( + info, + generate_data, + (M, N, K, seed, dtypes.bf16, dtypes.fp8), + run_gemm_flydsl_gfx1250, + ( + gemm_keys, + i, + ), + { + "num_warmup": args.warmup, + "num_iters": args.iters, + }, + run_torch, + ( + ref_keys, + dtypes.bf16, + ), + {}, + None, + 1e-2, + 0.01, + None, + None, + ("out",), + ) + ) + return tasks + def tune( self, untunedf, diff --git 
a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 73cbd1659a..4106469a62 100755 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -356,9 +356,18 @@ def calculate_total_valid_points(cu_count, aligned_k): def test_normal_gemm_a8w8_pertoken_quant( l_dtype, l_quantDtype, l_mnk, pad_a=128, skip_ck=False ): + is_gfx1250 = get_gfx() == "gfx1250" + if is_gfx1250 and not skip_ck: + aiter.logger.warning("gfx1250 has no CK a8w8 path; forcing skip_ck=True.") + skip_ck = True df = [] for dtype in l_dtype: for quantDtype in l_quantDtype: + if is_gfx1250 and quantDtype == dtypes.i8: + aiter.logger.warning( + "gfx1250 a8w8 only supports fp8 pertoken quant; skipping i8 shapes." + ) + continue for m, n, k in l_mnk: ret = test_gemm( dtype, m, n, k, quantDtype, pad_a=pad_a, skip_ck=skip_ck @@ -617,7 +626,8 @@ def _iter_flydsl_csv_cases(): df = test_normal_gemm_a8w8_pertoken_quant( args.dtype, args.quantDtype, args.mnk, args.pad_a ) - test_skinny_gemm_a8w8_pertoken_quant() + if get_gfx() != "gfx1250": + test_skinny_gemm_a8w8_pertoken_quant() if args.output and df is not None: os.makedirs(args.output, exist_ok=True) diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py deleted file mode 100644 index 0fdd2f6f4f..0000000000 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ /dev/null @@ -1,184 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Correctness tests for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. - -aiter.gemm_a8w8_bpreshuffle routes its FlyDSL path to the WMMA backend -(bpreshuffle_gemm_gfx1250) by the kernelName prefix ``flydsl_bpreshuffle_wmma_``. -Semantics are the ordinary a8w8 per-token (x_scale[M]) / per-channel (w_scale[N]) -fp8 GEMM, so inputs are quantized exactly like the standard a8w8 path. Skipped off -gfx1250. 
-""" - -import pytest -import torch - -import aiter -from aiter.utility import dtypes -from aiter.ops.shuffle import shuffle_weight -from aiter.jit.utils.chip_info import get_gfx_runtime as get_gfx -from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name - -pytestmark = pytest.mark.skipif( - not torch.cuda.is_available() or get_gfx() != "gfx1250", - reason="gfx1250 WMMA a8w8 bpreshuffle requires a gfx1250 device", -) - - -def _kernel_name( - tile_m, tile_n, tile_k, num_buffers, split_k=1, cluster_m=1, cluster_n=1 -): - return wmma_kernel_name( - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - num_buffers=num_buffers, - split_k=split_k, - cluster_m=cluster_m, - cluster_n=cluster_n, - ) - - -def _metrics(out, ref): - out_f, ref_f = out.float(), ref.float() - rel = (out_f - ref_f).abs().sum() / ref_f.abs().sum().clamp_min(1e-6) - cos = torch.nn.functional.cosine_similarity(out_f.flatten(), ref_f.flatten(), dim=0) - return rel.item(), cos.item() - - -def _quant(M, N, K): - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * 2.0 - b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * 2.0 - aq, a_scale = aiter.pertoken_quant(a, quant_dtype=dtypes.fp8) # [M, 1] - bq, b_scale = aiter.pertoken_quant(b, quant_dtype=dtypes.fp8) # [N, 1] - return aq, bq, a_scale, b_scale - - -def _ref(aq, bq, a_scale, b_scale, dtype): - a_f = aq.to(torch.float32) * a_scale.to(torch.float32) - b_f = bq.to(torch.float32) * b_scale.to(torch.float32) - return (a_f @ b_f.t()).to(dtype) - - -def _inject_tuned_config(monkeypatch, name): - import aiter.ops.gemm_op_a8w8 as gmod - - config = {"libtype": "flydsl", "splitK": 1, "kernelName": name} - monkeypatch.setattr(gmod, "get_GEMM_config_with_quant_type", lambda *a, **k: config) - - -def test_kernel_name_roundtrips(): - """Every catalogue kernelName must decode back to its config.""" - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import parse_wmma_kernel_name - from 
aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( - kernels_list, - ) - - assert kernels_list, "WMMA catalogue is empty" - ki = kernels_list[0] - cfg = parse_wmma_kernel_name(ki.name) - assert cfg is not None, f"cannot parse {ki.name}" - assert (cfg["tile_m"], cfg["tile_n"], cfg["tile_k"]) == ( - ki.tile_m, - ki.tile_n, - ki.tile_k, - ) - assert cfg["num_buffers"] == ki.num_buffers and cfg["split_k"] == ki.split_k - assert cfg["cluster_m"] == ki.cluster_m and cfg["cluster_n"] == ki.cluster_n - - -@pytest.mark.parametrize( - "M,N,K", - [ - (256, 256, 256), - (512, 1024, 512), - (1, 4096, 4096), # decode: M=1 (padded internally, no NaN) - (333, 576, 1024), # ragged M, N=576 (tile_n=64 divides it) - (17, 64, 512), # tiny ragged M, small N - ], -) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -def test_gemm_a8w8_bpreshuffle_gfx1250(M, N, K, dtype, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 64, 128, 2)) - torch.manual_seed(0) - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, dtype) - bq_prepared = shuffle_weight(bq, layout=(16, 16)) - out = aiter.gemm_a8w8_bpreshuffle(aq, bq_prepared, a_scale, b_scale, dtype=dtype) - - assert out.shape == (M, N) - assert out.dtype == dtype - rel, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (M={M},N={N},K={K})" - assert rel < 0.05, f"rel L1={rel} too high (M={M},N={N},K={K})" - - -@pytest.mark.parametrize("num_buffers", [2, 3, 4]) -def test_num_buffers(num_buffers, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, num_buffers)) - torch.manual_seed(0) - M, N, K = 256, 256, 1024 # K/tile_k = 8 >= 4 buffers - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - _, cos = _metrics(out, ref) - assert cos > 0.99, 
f"cosine={cos} too low (num_buffers={num_buffers})" - - -@pytest.mark.parametrize("split_k", [2, 4]) -def test_split_k(split_k, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) - torch.manual_seed(0) - M, N, K = 256, 256, 1024 - aq, bq, a_scale, b_scale = _quant(M, N, K) - bq_sh = shuffle_weight(bq, layout=(16, 16)) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle(aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (split_k={split_k})" - - # split-k must accumulate in fp32 - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=1)) - out_sk1 = aiter.gemm_a8w8_bpreshuffle( - aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16 - ) - rel, cos = _metrics(out, out_sk1) - assert rel < 1e-3, f"split_k={split_k} drifts from sk1 (rel L1={rel})" - assert cos > 0.9999, f"split_k={split_k} drifts from sk1 (cos={cos})" - - -def test_cluster(monkeypatch): - """Workgroup cluster (cluster_m/n>1) over an evenly divisible grid.""" - _inject_tuned_config( - monkeypatch, _kernel_name(128, 128, 128, 2, cluster_m=2, cluster_n=2) - ) - torch.manual_seed(0) - M, N, K = 512, 512, 512 # grid (4, 4) divisible by cluster (2, 2) - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low" - - -def test_backend_direct_writes_out(): - """The gfx1250 backend writes into the caller's Out tensor.""" - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( - run_preshuffle_gemm_a8_gfx1250, - ) - - torch.manual_seed(0) - M, N, K = 512, 512, 512 - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = torch.empty(M, N, device="cuda", 
dtype=torch.bfloat16) - ret = run_preshuffle_gemm_a8_gfx1250( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, out, 128, 128, 128 - ) - assert ret.data_ptr() == out.data_ptr() - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low"