From 01d7160b625186071fc98a3914a72216e4a1376b Mon Sep 17 00:00:00 2001 From: aoli Date: Sat, 6 Jun 2026 17:05:55 +0000 Subject: [PATCH 01/19] update ptpc m oob and remove m pad --- aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | 39 +-- ...lydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 11 +- .../ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | 85 ++++-- aiter/ops/flydsl/kernels/tdm_oob.py | 281 ++++++++++++++++++ .../test_gemm_a8w8_bpreshuffle_gfx1250.py | 43 ++- 5 files changed, 412 insertions(+), 47 deletions(-) create mode 100644 aiter/ops/flydsl/kernels/tdm_oob.py diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index 7562f31c31..755f519574 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -6,9 +6,10 @@ aiter.gemm_a8w8_bpreshuffle routes here when its tuned kernelName starts with ``flydsl_bpreshuffle_wmma_`` (gfx1250 has no MFMA preshuffle kernel). Runs the vendored gemm_fp8fp4_gfx1250 WMMA kernel in ptpc scale mode: C = (A*sa) @ (B*sb)^T -with fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. N/K are not -padded (must divide the tile); M is padded to tile_m when ragged, since the kernel -reads a full tile_m rows per workgroup (M=1 would otherwise read past A/sa -> NaN). +with fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. N/K must +divide the tile; M may be non-tile-aligned (ragged) with no host padding — the +kernel clips A/A-scale loads and the C store to the runtime M via hardware +out-of-bounds handling (split-k predicates the atomic add per-lane on row < M). """ from __future__ import annotations @@ -125,21 +126,14 @@ def run_preshuffle_gemm_a8_gfx1250( sa = _as_1d_fp32(x_scale, M, "x_scale") sb = _as_1d_fp32(w_scale, N, "w_scale") - # M padded to tile_m when ragged (kernel reads a full tile_m rows/wg). - padded_m = ((M + tile_m - 1) // tile_m) * tile_m - if padded_m == M: - a_dev = XQ.contiguous() - sa_dev = sa - else: - a_dev = torch.zeros((padded_m, K), dtype=XQ.dtype, device=XQ.device) - a_dev[:M] = XQ - sa_dev = torch.ones((padded_m,), dtype=torch.float32, device=sa.device) - sa_dev[:M] = sa - + # Ragged M needs no host padding: the kernel clips A/A-scale loads and the C + # store to the runtime M via hardware out-of-bounds, so A and the scales pass + # through unchanged. Only N/K must divide the tile (checked above). + a_dev = XQ.contiguous() b_dev = WQ.contiguous() exe = _compile_ptpc_gemm( - M=padded_m, + M=M, N=N, K=K, data_format="fp8", @@ -157,12 +151,11 @@ def run_preshuffle_gemm_a8_gfx1250( ) if accumulate_fp32: - # fp32 atomic-accumulation scratch (zeroed; narrowed into Out below). - out_buf = torch.zeros((padded_m, N), dtype=torch.float32, device=Out.device) - elif padded_m == M: - out_buf = Out.contiguous() + # fp32 atomic-accumulation scratch: zeroed because the split-k atomic add + # accumulates into it (per-lane predicated on row < M), cast into Out below. + out_buf = torch.zeros((M, N), dtype=torch.float32, device=Out.device) else: - out_buf = torch.empty((padded_m, N), dtype=Out.dtype, device=Out.device) + out_buf = Out.contiguous() stream = _fx.Stream(torch.cuda.current_stream(device=a_dev.device)) _run_compiled( @@ -170,15 +163,15 @@ def run_preshuffle_gemm_a8_gfx1250( out_buf.view(-1), _to_uint8(a_dev), _to_uint8(b_dev), - sa_dev.contiguous().view(-1), + sa.contiguous().view(-1), sb.contiguous().view(-1), - padded_m, + M, N, stream, ) if out_buf.data_ptr() != Out.data_ptr(): - Out.copy_(out_buf[:M]) + Out.copy_(out_buf) return Out diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index 13a4b12865..af0d3dd656 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -91,10 +91,11 @@ def _build_kernels_list(): def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: """N must divide tile_n (N is never padded); K must divide split_k*tile_k, and - each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. M is - padded to tile_m, so ragged M is fine without a cluster; a cluster needs an - evenly divisible grid and only pays off for M, N >= 4096. - (LDS is bounded at build time, so it is not re-checked here.) + each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. M may + be ragged — the kernel clips loads/stores to M via hardware out-of-bounds, so no + M divisibility is required (a cluster just rounds the M-grid up and OOB-clips the + extra tiles). A cluster still needs N cluster-tile-divisible and only pays off + for M, N >= 4096. (LDS is bounded at build time, so it is not re-checked here.) """ if N % ki.tile_n != 0 or K % (ki.split_k * ki.tile_k) != 0: return False @@ -103,6 +104,6 @@ def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: if ki.cluster_m > 1 or ki.cluster_n > 1: if M < 4096 or N < 4096: return False - if M % (ki.cluster_m * ki.tile_m) != 0 or N % (ki.cluster_n * ki.tile_n) != 0: + if N % (ki.cluster_n * ki.tile_n) != 0: return False return True diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index baa59da3dd..6bc90b7839 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -13,7 +13,7 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir -from flydsl._mlir.dialects import fly, llvm +from flydsl._mlir.dialects import fly, llvm, scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import ( arith, @@ -57,17 +57,36 @@ def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): _llvm.inline_asm(None, [], "\n".join(lines), "", has_side_effects=True) -# compatible with no early_timeout descriptor -_TDM_HAS_EARLY_TIMEOUT = ( - "early_timeout" in inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters -) +# Feature-detect the installed flydsl's TDM descriptor builder. Older pinned +# flydsl predates these args; we apply each only when supported and otherwise +# fall back to the vendored OOB-capable builder for the non-tile-aligned-M path. +_TDM_SIG_PARAMS = inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters +_TDM_HAS_EARLY_TIMEOUT = "early_timeout" in _TDM_SIG_PARAMS +_TDM_HAS_OOB = "oob_outer_bound" in _TDM_SIG_PARAMS + +def _make_tdm_desc(*, early_timeout=False, oob_outer_bound=None, **kwargs): + """Build a TDM descriptor across flydsl versions. -def _make_tdm_desc(*, early_timeout=False, **kwargs): - """Build a TDM descriptor, applying early_timeout only when supports it.""" - if _TDM_HAS_EARLY_TIMEOUT: - kwargs["early_timeout"] = early_timeout - return tdm_ops.make_tensor_descriptor_2d(**kwargs) + ``early_timeout`` / ``oob_outer_bound`` are passed to the native builder + only when it supports them. When ``oob_outer_bound`` is needed (ragged M) + but the installed flydsl lacks it, the vendored builder in ``tdm_oob`` is + used — it carries the OOB logic so the non-tile-aligned-M GEMM path works + against the older flydsl this build pins. + """ + if _TDM_HAS_OOB: + return tdm_ops.make_tensor_descriptor_2d( + early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs + ) + if oob_outer_bound is None: + if _TDM_HAS_EARLY_TIMEOUT: + kwargs["early_timeout"] = early_timeout + return tdm_ops.make_tensor_descriptor_2d(**kwargs) + from .tdm_oob import make_tensor_descriptor_2d as _make_tensor_descriptor_2d_oob + + return _make_tensor_descriptor_2d_oob( + early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs + ) # Common constants @@ -716,6 +735,7 @@ def make_desc_a(memref, k_base): workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b(memref, k_base): @@ -754,6 +774,7 @@ def make_desc_a_half(memref, k_base, m_half: int): lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=i32_m, ) def make_desc_b_half(memref, k_base, n_half: int): @@ -2351,12 +2372,18 @@ def epilogue_atomic_adds(final_accs, addrs): addr_idx = 0 for acc_idx, vec_base, m_off, wn in _sub_tiles: sub8 = _get_acc_sub8(final_accs, acc_idx, vec_base) - if const_expr(_bf16_out): - addr_idx += _atomic_add_acc_vec8_to_buffer(sub8, addrs[addr_idx]) - else: - addr_idx += _atomic_add_acc_vec8_to_buffer( - sub8, addrs[addr_idx : addr_idx + 2] - ) + n_slots = 1 if _bf16_out else 2 + addr_arg = ( + addrs[addr_idx] if _bf16_out else addrs[addr_idx : addr_idx + 2] + ) + # Atomics use a raw global ptr (no num_records clip), so predicate + # per-lane to skip rows >= M. + row = blk_m + warp_m_base + arith.index(m_off) + lane16 + if_op = scf.IfOp(row < m_idx, [], has_else=False) + with ir.InsertionPoint(if_op.then_block): + _atomic_add_acc_vec8_to_buffer(sub8, addr_arg) + scf.YieldOp([]) + addr_idx += n_slots def grouped_accs_to_row_major(accs_grouped): row_major = [None] * n_accs @@ -2565,6 +2592,7 @@ def _l2_prefetch(k_base): num_warps=1, lds_byte_offset=d_warp_off_sgpr, for_store=True, + oob_outer_bound=i32_m, ) # TDM descriptor lane layout: dgroup0 = [predicate, lds_addr, addr_lo, addr_hi]. @@ -3121,7 +3149,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _ptpc_sa, _ptpc_sb = _ptpc_scale_box[0] accs = epilogue_apply_ptpc_scale(accs, _ptpc_sa, _ptpc_sb) - if const_expr(use_tdm_store): + def _emit_tdm_store(): if const_expr(d_need_epilogue_fence): _pipeline_fence(outstanding=0) rocdl.sched_barrier(0) @@ -3129,7 +3157,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): rocdl.s_wait_dscnt(0) tdm_ops.tensor_store_2d(d_desc) tdm_ops.tensor_wait(0) - else: + + def _emit_buffer_store(): rocdl.sched_barrier(0) if const_expr(epi_addrs_box[0] is None): epi_addrs_box[0] = epilogue_prepare_addrs() @@ -3138,6 +3167,21 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): else: epilogue_stores(accs, epi_addrs_box[0]) + if const_expr(use_tdm_store): + # Full M-tiles take the fast TDM store; the partial last M-tile + # (rows >= M) falls back to the buffer store, whose num_records clip + # drops the OOB rows. + full_tile = (blk_m + arith.index(tile_m)) <= m_idx + if_op = scf.IfOp(full_tile, [], has_else=True) + with ir.InsertionPoint(if_op.then_block): + _emit_tdm_store() + scf.YieldOp([]) + with ir.InsertionPoint(if_op.else_block): + _emit_buffer_store() + scf.YieldOp([]) + else: + _emit_buffer_store() + cache_tag = ( data_format, scale_mode, @@ -3187,6 +3231,11 @@ def launch_mxscale_gemm( gy = (i32_n + (tile_n - 1)) // tile_n gz = split_k + if const_expr(use_cluster): + # Cluster launch needs a cluster-divisible grid; the extra M-tiles + # are fully OOB (rows >= M) and the kernel clips them. + gx = ((gx + (cluster_m - 1)) // cluster_m) * cluster_m + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None kernel_mxscale_gemm( arg_c, diff --git a/aiter/ops/flydsl/kernels/tdm_oob.py b/aiter/ops/flydsl/kernels/tdm_oob.py new file mode 100644 index 0000000000..0bf1aac7b6 --- /dev/null +++ b/aiter/ops/flydsl/kernels/tdm_oob.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Vendored OOB-capable TDM 2D descriptor builder for gfx1250. + +This is a faithful copy of ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` +as of the FlyDSL "add M out-of-bounds support" change, carried here so the +non-tile-aligned-M (OOB) GEMM path works against the *older* flydsl this aiter +build pins, whose ``make_tensor_descriptor_2d`` predates the ``oob_outer_bound`` +argument. + +The kernel only routes through this fallback when the installed flydsl lacks +native ``oob_outer_bound`` support (see ``_make_tdm_desc`` in +``gemm_fp8fp4_gfx1250``); when flydsl has it, the native builder is used. + +To stay robust across flydsl internal-layout changes, every low-level symbol is +sourced from the installed ``tdm_ops`` module namespace (i.e. whatever that +module successfully imported) rather than re-imported from private paths. The +only behavioural delta vs. the pinned builder is the ``oob_outer_bound`` branch +that computes a runtime ``tensor_dim1``; with ``oob_outer_bound=None`` the output +is byte-identical to the original path. +""" + +from __future__ import annotations + +import math +from typing import Tuple, Union + +from flydsl.expr.rocdl import tdm_ops as _tdm + +# Reuse whatever the installed tdm_ops bound — keeps us in lock-step with the +# pinned flydsl's lower-level primitives instead of guessing private paths. +ir = _tdm.ir +std_arith = _tdm.std_arith +llvm_dialect = _tdm.llvm_dialect +memref_dialect = _tdm.memref_dialect +arith = _tdm.arith +vector = _tdm.vector +_raw = _tdm._raw +T = _tdm.T +_ArithValue = _tdm._ArithValue +compute_warp_distribution = _tdm.compute_warp_distribution +compute_padding_encoding = _tdm.compute_padding_encoding +TDMDescriptor2D = _tdm.TDMDescriptor2D + + +def make_tensor_descriptor_2d( + global_ptr, + lds_memref, + global_offset: Tuple, + tensor_shape: Tuple[int, int], + strides: Tuple[int, int], + tile_shape: Tuple[int, int], + elem_bytes: int = 2, + pad_interval: int = 0, + pad_amount: int = 0, + num_warps: int = 1, + cache_policy: int = 0, + pred: int = 1, + workgroup_mask: Union[int, "ir.Value"] = 0, + lds_byte_offset=None, + for_store: bool = False, + atomic_barrier_enable: bool = False, + early_timeout: bool = False, + oob_outer_bound=None, +) -> "TDMDescriptor2D": + """Build a 2D TDM descriptor (vendored, OOB-capable). + + See ``flydsl.expr.rocdl.tdm_ops.make_tensor_descriptor_2d`` for the full + argument reference. ``oob_outer_bound`` is the runtime outer-dim global + extent (real M for a row-major A/C); when given, ``tensor_dim1`` is set to + the tile-start-relative remaining extent + ``max(0, oob_outer_bound - (outer_off + warp_off_outer))`` while + ``tile_dim1`` stays the full per-warp tile, so the partial last tile exceeds + the tensor bound and the hardware OOB-handles the overhang (fault-safe load, + zero-fill in LDS). Accepts a Python int or an i32/index ir.Value. ``None`` + keeps ``tensor_dim1 == tile_dim1`` (OOB off) — byte-identical to the + non-OOB path. + """ + from flydsl._mlir.dialects import fly as _fly_d + + outer_size, inner_size = tensor_shape + outer_stride, inner_stride = strides + outer_tile, inner_tile = tile_shape + outer_off, inner_off = global_offset + + # -- Warp distribution -- + warps_per_dim, block_per_warp = compute_warp_distribution( + [outer_tile, inner_tile], + num_warps, + ) + bpw_outer, bpw_inner = block_per_warp + warps_dim0 = warps_per_dim[0] + + if num_warps > 1: + # Auto-acquire SGPR wave_id via hardware register (TTMP8[29:25]). + # This keeps the entire descriptor address chain in SALU, + from flydsl.expr import rocdl as _rocdl_ext + + _wid_i32 = _rocdl_ext.wave_id() + wave_id = arith.index_cast(T.index, _wid_i32) + warp_coord_outer = wave_id % arith.index(warps_dim0) + warp_coord_inner = wave_id / arith.index(warps_dim0) + warp_off_outer = warp_coord_outer * arith.index(bpw_outer) + warp_off_inner = warp_coord_inner * arith.index(bpw_inner) + else: + warp_off_outer = arith.index(0) + warp_off_inner = arith.index(0) + + # -- Global address (byte address for descriptor) -- + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__extract_to_ir_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + glb_elem_off = (outer_off + warp_off_outer) * arith.index(outer_stride) + ( + inner_off + warp_off_inner + ) * arith.index(inner_stride) + glb_byte_off = glb_elem_off * arith.index(elem_bytes) + glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) + glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 + + # -- LDS address (byte address within shared memory) -- + lds_base_idx = _ArithValue( + memref_dialect.extract_aligned_pointer_as_index(lds_memref) + ) + # Compute padded LDS stride (elements) for the outer dim + if pad_interval > 0 and pad_amount > 0: + lds_inner_stride = inner_tile + pad_amount # padded row width + else: + lds_inner_stride = inner_tile + lds_warp_elem_off = warp_off_outer * arith.index(lds_inner_stride) + warp_off_inner + lds_warp_byte_off = lds_warp_elem_off * arith.index(elem_bytes) + lds_total_off = lds_base_idx + lds_warp_byte_off + if lds_byte_offset is not None: + lds_total_off = lds_total_off + lds_byte_offset + lds_addr_i32 = arith.index_cast(T.i32, lds_total_off) + + # ================================================================ + # GROUP0 (vector<4xi32>): pred, lds_addr, global_addr_lo/hi + # ================================================================ + g0_s0 = arith.constant(pred, type=T.i32) + g0_s1 = lds_addr_i32 + i32 = ir.IntegerType.get_signless(32) + g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_addr_i64)).result) + hi_raw = _ArithValue(_raw(glb_addr_i64)).shrui(arith.constant(32, type=T.i64)) + g0_s3 = _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) | arith.constant( + 1 << 31, type=T.i32 + ) # type field = 2 in [31:30] + dgroup0 = vector.from_elements(T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3]) + + # ================================================================ + # GROUP1 (vector<8xi32>): config + tensor dims + strides + tile + # ================================================================ + # Descriptor dim ordering: dim0=innermost, dim1=outermost + tdim0 = bpw_inner # innermost extent per warp + tdim1 = bpw_outer # outermost extent per warp + tile_d0 = bpw_inner # block dim0 per warp + tile_d1 = bpw_outer # block dim1 per warp + + # Padding can be applied to the LDS address when copying from memory to LDS, + # but not when copying from LDS to memory + # (there is no "de-padding" operation; padding is ignored). + if for_store and pad_interval > 0 and pad_amount > 0: + tile_d0 += pad_amount + pad_interval = 0 + pad_amount = 0 + + # stride_dim0 in descriptor = outermost stride in elements + stride0 = outer_stride + + # data_size = log2(elem_bytes) + data_size_code = int(math.log2(elem_bytes)) + + # Padding encoding + if pad_interval > 0 and pad_amount > 0: + elem_bits = elem_bytes * 8 + enc_interval, enc_amount = compute_padding_encoding( + pad_interval, pad_amount, elem_bits + ) + pad_enable = 1 + else: + enc_interval, enc_amount = 0, 0 + pad_enable = 0 + + # sgpr0: config bitfields + _abe = 1 if atomic_barrier_enable else 0 + _early_timeout = 1 if early_timeout else 0 + g1_s0_upper = ( + (data_size_code << 16) # data_size [17:16] + | (_abe << 18) # atomic_barrier_enable + | (0 << 19) # iterate_enable + | (pad_enable << 20) # pad_enable + | (_early_timeout << 21) # early_timeout + | (enc_interval << 22) # pad_interval [24:22] + | (enc_amount << 25) # pad_amount [31:25] + ) + + if isinstance(workgroup_mask, int): + g1_s0_val = (workgroup_mask & 0xFFFF) | g1_s0_upper + g1_s0 = arith.constant(g1_s0_val, type=T.i32) + else: + upper_const = arith.constant(g1_s0_upper, type=T.i32) + mask_i32 = arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32)) + g1_s0 = arith.ori(upper_const, mask_i32) + + # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] + g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) + + if oob_outer_bound is None: + # Compile-time tensor_dim1 == tile extent: OOB checking off. + # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] + g1_s2 = arith.constant( + ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), + type=T.i32, + ) + # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] + g1_s3 = arith.constant( + ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), + type=T.i32, + ) + else: + # Runtime tensor_dim1 = max(0, oob_outer_bound - (outer_off + warp_off_outer)), + # tile-start-relative (the descriptor's global address already includes the + # tile/warp start). tile_dim1 (sgpr4) stays the full per-warp tile, so the + # partial last tile exceeds the tensor bound and the HW OOB-handles the + # overhang. tensor_dim0 (innermost) and the tile dims stay compile-time. + if isinstance(oob_outer_bound, int): + ob_i32 = arith.constant(oob_outer_bound, type=T.i32) + else: + ob_i32 = ( + oob_outer_bound.ir_value() + if hasattr(oob_outer_bound, "ir_value") + else oob_outer_bound + ) + if not isinstance(ob_i32, ir.Value): + raise TypeError( + f"oob_outer_bound must be int or i32/index ir.Value, " + f"got {type(oob_outer_bound).__name__}" + ) + if isinstance(ob_i32.type, ir.IndexType): + ob_i32 = arith.index_cast(T.i32, ob_i32) + elif not ( + isinstance(ob_i32.type, ir.IntegerType) and ob_i32.type.width == 32 + ): + raise TypeError( + f"oob_outer_bound ir.Value must be index or i32, got {ob_i32.type}" + ) + start_i32 = arith.index_cast(T.i32, outer_off + warp_off_outer) + tdim1_rt = arith.maxsi( + arith.subi(ob_i32, start_i32), arith.constant(0, type=T.i32) + ) + c16 = arith.constant(16, type=T.i32) + c_mask16 = arith.constant(0xFFFF, type=T.i32) + # sgpr2: tensor_dim0_hi[15:0] (const) | tensor_dim1_lo[31:16] (runtime) + g1_s2 = arith.ori( + arith.constant((tdim0 >> 16) & 0xFFFF, type=T.i32), + arith.shli(arith.andi(tdim1_rt, c_mask16), c16), + ) + # sgpr3: tensor_dim1_hi[15:0] (runtime) | tile_dim0[31:16] (const) + g1_s3 = arith.ori( + arith.andi(arith.shrui(tdim1_rt, c16), c_mask16), + arith.constant(tile_d0 << 16, type=T.i32), + ) + + # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 (always the full per-warp tile) + g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) + + # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + + # sgpr6-7: for 2D, no higher-dim strides + g1_s6 = arith.constant(0, type=T.i32) + g1_s7 = arith.constant(0, type=T.i32) + + dgroup1 = vector.from_elements( + T.vec(8, T.i32), + [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], + ) + + return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py index 0fdd2f6f4f..71608233de 100644 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py @@ -92,7 +92,7 @@ def test_kernel_name_roundtrips(): [ (256, 256, 256), (512, 1024, 512), - (1, 4096, 4096), # decode: M=1 (padded internally, no NaN) + (1, 4096, 4096), # decode: M=1 (ragged, kernel OOB-clips; no host pad) (333, 576, 1024), # ragged M, N=576 (tile_n=64 divides it) (17, 64, 512), # tiny ragged M, small N ], @@ -149,6 +149,47 @@ def test_split_k(split_k, monkeypatch): assert cos > 0.9999, f"split_k={split_k} drifts from sk1 (cos={cos})" +@pytest.mark.parametrize("split_k", [2, 4]) +@pytest.mark.parametrize("M", [17, 100, 257]) +def test_ragged_m_split_k(M, split_k, monkeypatch): + """Ragged M with split-k exercises the per-lane (row < M) atomic predicate.""" + _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) + torch.manual_seed(0) + N, K = 256, 1024 # K/(split_k*tile_k) integral; chunk holds >= 2 buffers + aq, bq, a_scale, b_scale = _quant(M, N, K) + ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) + out = aiter.gemm_a8w8_bpreshuffle( + aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 + ) + assert out.shape == (M, N) + _, cos = _metrics(out, ref) + assert cos > 0.99, f"cosine={cos} too low (ragged M={M}, split_k={split_k})" + + +def test_vendored_oob_path(monkeypatch): + """Force the vendored OOB descriptor builder (the older-flydsl fallback) and + verify ragged-M correctness, even when the installed flydsl has native OOB. + + Uses a tile config no other test compiles so the cached kernel is built fresh + through the vendored path while ``_TDM_HAS_OOB`` is patched off. + """ + import aiter.ops.flydsl.kernels.gemm_fp8fp4_gfx1250 as kmod + + monkeypatch.setattr(kmod, "_TDM_HAS_OOB", False) + _inject_tuned_config(monkeypatch, _kernel_name(64, 64, 128, 2)) + torch.manual_seed(0) + M, N, K = 100, 256, 512 # ragged M -> partial last M-tile via vendored desc + aq, bq, a_scale, b_scale = _quant(M, N, K) + ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) + out = aiter.gemm_a8w8_bpreshuffle( + aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 + ) + assert out.shape == (M, N) + rel, cos = _metrics(out, ref) + assert cos > 0.99, f"cosine={cos} too low (vendored OOB path)" + assert rel < 0.05, f"rel L1={rel} too high (vendored OOB path)" + + def test_cluster(monkeypatch): """Workgroup cluster (cluster_m/n>1) over an evenly divisible grid.""" _inject_tuned_config( From b0288339794ddc4d1100a29220edee1935740081 Mon Sep 17 00:00:00 2001 From: aoli Date: Sun, 7 Jun 2026 02:57:24 +0000 Subject: [PATCH 02/19] [gfx1250] tune m_warp/n_warp for ptpc bpreshuffle gemm; fix LDS estimate - catalogue: exact per-stage LDS arena estimate (A row-pad + 16/128/1024 alignment + epilogue D buffer) so over-LDS tiles (e.g. t64x256x256_nb4 = 331776 B > 320 KiB) are filtered out instead of faulting the GPU at launch - make m_warp/n_warp first-class tuned params: kernel name carries mw{m}_nw{n} (mandatory), catalogue sweeps _WARP_COMBOS and allows tile_m=16, host dispatch and tuner thread them through; reaches the decode-winning m_warp=1 configs - compile with M=0: compile-time M is codegen-unused (runtime i32_m drives all bounds via OOB), so the kernel caches per (N,K,config) and is reused across M instead of recompiling per M - kernel_fits_shape: drop the now-unneeded cluster M-divisibility (OOB handles ragged M; grid rounds up), keep N tile/cluster divisibility - drop a redundant double .contiguous() on the scale tensors - tests: m_warp/n_warp config coverage, ragged-M split-k, full name roundtrip --- aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | 20 +++-- ...lydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 82 +++++++++++++++---- .../gemm_a8w8_bpreshuffle_wmma_tune.py | 2 + .../test_gemm_a8w8_bpreshuffle_gfx1250.py | 55 ++++++++++--- 4 files changed, 127 insertions(+), 32 deletions(-) diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index 755f519574..debf0b188d 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -132,8 +132,11 @@ def run_preshuffle_gemm_a8_gfx1250( a_dev = XQ.contiguous() b_dev = WQ.contiguous() + # Compile-time M is unused for codegen (the kernel reads the runtime i32_m for + # all bounds/grid), so compile with M=0: the kernel is cached (lru_cache) per + # (N, K, config) and reused across every M instead of recompiling per M. exe = _compile_ptpc_gemm( - M=M, + M=0, N=N, K=K, data_format="fp8", @@ -158,13 +161,14 @@ def run_preshuffle_gemm_a8_gfx1250( out_buf = Out.contiguous() stream = _fx.Stream(torch.cuda.current_stream(device=a_dev.device)) + # sa/sb are already contiguous 1-D fp32 (see _as_1d_fp32). _run_compiled( exe, out_buf.view(-1), _to_uint8(a_dev), _to_uint8(b_dev), - sa.contiguous().view(-1), - sb.contiguous().view(-1), + sa.view(-1), + sb.view(-1), M, N, stream, @@ -175,10 +179,11 @@ def run_preshuffle_gemm_a8_gfx1250( return Out -# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_nb{nb}_sk{sk}_cm{cm}_cn{cn} +# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_mw{mw}_nw{nw}_nb{nb}_sk{sk}_cm{cm}_cn{cn} _KERNEL_NAME_RE = re.compile( r"^flydsl_bpreshuffle_wmma_" r"t(?P\d+)x(?P\d+)x(?P\d+)_" + r"mw(?P\d+)_nw(?P\d+)_" r"nb(?P\d+)_sk(?P\d+)_" r"cm(?P\d+)_cn(?P\d+)$" ) @@ -193,10 +198,13 @@ def wmma_kernel_name( split_k: int, cluster_m: int, cluster_n: int, + m_warp: int, + n_warp: int, ) -> str: return ( f"flydsl_bpreshuffle_wmma_t{tile_m}x{tile_n}x{tile_k}_" - f"nb{num_buffers}_sk{split_k}_cm{cluster_m}_cn{cluster_n}" + f"mw{m_warp}_nw{n_warp}_nb{num_buffers}_sk{split_k}_" + f"cm{cluster_m}_cn{cluster_n}" ) @@ -231,4 +239,6 @@ def run_gemm_a8w8_bpreshuffle_gfx1250( split_k=cfg["split_k"], cluster_m=cfg["cluster_m"], cluster_n=cfg["cluster_n"], + m_warp=cfg["m_warp"], + n_warp=cfg["n_warp"], ) diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index af0d3dd656..645d5e844a 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -16,16 +16,29 @@ from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name -WMMA = 16 -WARP = 2 # m_warp = n_warp = 2 -> tile_m/n must be multiples of 32 -LDS_BYTES = 320 * 1024 - -_TILE_M = (32, 64, 256) -_TILE_N = (32, 64, 256) +WMMA = 16 # WMMA M/N tile granularity +WARP = 2 # default m_warp / n_warp for WmmaKernelInstance +LDS_BYTES = 320 * 1024 # gfx1250 shared-memory capacity (== device limit) + +# Mirror the ptpc fp8 LDS layout in gemm_fp8fp4_gfx1250.compile_fp8fp4_gemm so the +# candidate filter matches the kernel's real allocation. fp8 packs 1 byte/elem +# (packed_tile_k == tile_k); ptpc keeps no scale pool in LDS. Keep in sync with +# that file's LDS_PAD_A_BYTES / LDS_PAD_D_BYTES / elem_bytes_d. +_LDS_PAD_A_BYTES = 16 +_LDS_PAD_D_BYTES = 16 +_ELEM_BYTES_D = 2 # bf16 / f16 output + +_TILE_M = (16, 32, 64, 256) +_TILE_N = (32, 64, 128, 256) _TILE_K = (128, 256) -_NUM_BUFFERS = (2, 3, 4) -_SPLIT_K = (1, 2, 4, 8, 16) -_CLUSTER = ((1, 1), (2, 2), (2, 4), (4, 2)) # cluster_m * cluster_n <= 16 +_NUM_BUFFERS = (4,) +_SPLIT_K = (1, 2, 4) +# (m_warp, n_warp): wave-specialized ptpc requires m_warp*n_warp >= 2, and +# block_threads = m_warp*n_warp*32 <= 1024 (m_warp*n_warp <= 32). m_warp=1 with a +# small tile_m wins for decode (small M); m_warp=2 serves larger M; n_warp=4 helps +# very wide N. Add (1, 4) / (4, 2) here if a shape benefits. +_WARP_COMBOS = ((1, 2), (2, 2), (2, 4)) +_CLUSTER = ((1, 1),) # cluster_m * cluster_n <= 16 @dataclass @@ -37,6 +50,8 @@ class WmmaKernelInstance: split_k: int = 1 cluster_m: int = 1 cluster_n: int = 1 + m_warp: int = WARP + n_warp: int = WARP @property def name(self) -> str: @@ -48,15 +63,50 @@ def name(self) -> str: split_k=self.split_k, cluster_m=self.cluster_m, cluster_n=self.cluster_n, + m_warp=self.m_warp, + n_warp=self.n_warp, ) -def _tile_valid(tm: int, tn: int, tk: int) -> bool: - return tm % (WARP * WMMA) == 0 and tn % (WARP * WMMA) == 0 and tk % 128 == 0 +def _tile_valid(tm: int, tn: int, tk: int, mw: int, nw: int) -> bool: + # Mirrors the kernel's warp_tile_m/n constraints: each must be a multiple of + # WMMA (16). block_threads = mw*nw*32 <= 1024 (mw*nw <= 32). + return ( + tm % (mw * WMMA) == 0 + and tn % (nw * WMMA) == 0 + and tk % 128 == 0 + and mw * nw <= 32 + ) + + +def _align_up(value: int, align: int) -> int: + return (value + align - 1) // align * align def kernel_instance_estimated_lds_bytes(ki: WmmaKernelInstance) -> int: - return (ki.tile_m * ki.tile_k + ki.tile_n * ki.tile_k) * ki.num_buffers + """LDS bytes the ptpc fp8 WMMA kernel actually allocates for ``ki``. + + Replicates the per-stage arena layout: each stage holds the A data pool + (rows padded by LDS_PAD_A_BYTES) followed by the 16-aligned B data pool (no + ptpc scale pool), the stage is 128- then 1024-aligned, and the arena is that + pitch times num_buffers. The split_k==1 epilogue also needs a TDM-store D + buffer, which can exceed the arena for small tiles, so take the max. The + estimate must be exact/conservative: an under-estimate would let an + overflowing tile through the candidate filter and fault the GPU at launch. + """ + lds_a_data = ki.tile_m * (ki.tile_k + _LDS_PAD_A_BYTES) + lds_b_data = ki.tile_n * ki.tile_k + stage_bytes = _align_up(lds_a_data, 16) + lds_b_data + stage_pitch = _align_up(_align_up(stage_bytes, 128), 1024) + arena_bytes = stage_pitch * ki.num_buffers + + if ki.split_k == 1: # split_k>1 uses the buffer/atomic store, no LDS D buffer + warp_tile_m = ki.tile_m // ki.m_warp + warp_tile_n = ki.tile_n // ki.n_warp + d_row_stride = warp_tile_n * _ELEM_BYTES_D + _LDS_PAD_D_BYTES + total_d_bytes = (ki.m_warp * ki.n_warp) * warp_tile_m * d_row_stride + return max(arena_bytes, total_d_bytes) + return arena_bytes def max_lds_bytes_for_tune() -> int: @@ -66,12 +116,12 @@ def max_lds_bytes_for_tune() -> int: def _build_kernels_list(): kl = {} idx = 0 - for nb, sk, (cm, cn), tm, tn, tk in product( - _NUM_BUFFERS, _SPLIT_K, _CLUSTER, _TILE_M, _TILE_N, _TILE_K + for nb, sk, (cm, cn), (mw, nw), tm, tn, tk in product( + _NUM_BUFFERS, _SPLIT_K, _CLUSTER, _WARP_COMBOS, _TILE_M, _TILE_N, _TILE_K ): - if not _tile_valid(tm, tn, tk): + if not _tile_valid(tm, tn, tk, mw, nw): continue - ki = WmmaKernelInstance(tm, tn, tk, nb, sk, cm, cn) + ki = WmmaKernelInstance(tm, tn, tk, nb, sk, cm, cn, mw, nw) if kernel_instance_estimated_lds_bytes(ki) > LDS_BYTES: continue kl[idx] = ki diff --git a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py b/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py index bb415faf46..154b617cd2 100644 --- a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py +++ b/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py @@ -78,6 +78,8 @@ def run_gemm(x, weight_shuffle, x_scale, w_scale, out, kernel_id): split_k=ki.split_k, cluster_m=ki.cluster_m, cluster_n=ki.cluster_n, + m_warp=ki.m_warp, + n_warp=ki.n_warp, ) return out diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py index 71608233de..ab7e3f6d1e 100644 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py @@ -26,7 +26,15 @@ def _kernel_name( - tile_m, tile_n, tile_k, num_buffers, split_k=1, cluster_m=1, cluster_n=1 + tile_m, + tile_n, + tile_k, + num_buffers, + split_k=1, + cluster_m=1, + cluster_n=1, + m_warp=2, + n_warp=2, ): return wmma_kernel_name( tile_m=tile_m, @@ -36,6 +44,8 @@ def _kernel_name( split_k=split_k, cluster_m=cluster_m, cluster_n=cluster_n, + m_warp=m_warp, + n_warp=n_warp, ) @@ -75,16 +85,17 @@ def test_kernel_name_roundtrips(): ) assert kernels_list, "WMMA catalogue is empty" - ki = kernels_list[0] - cfg = parse_wmma_kernel_name(ki.name) - assert cfg is not None, f"cannot parse {ki.name}" - assert (cfg["tile_m"], cfg["tile_n"], cfg["tile_k"]) == ( - ki.tile_m, - ki.tile_n, - ki.tile_k, - ) - assert cfg["num_buffers"] == ki.num_buffers and cfg["split_k"] == ki.split_k - assert cfg["cluster_m"] == ki.cluster_m and cfg["cluster_n"] == ki.cluster_n + for ki in kernels_list.values(): + cfg = parse_wmma_kernel_name(ki.name) + assert cfg is not None, f"cannot parse {ki.name}" + assert (cfg["tile_m"], cfg["tile_n"], cfg["tile_k"]) == ( + ki.tile_m, + ki.tile_n, + ki.tile_k, + ) + assert cfg["num_buffers"] == ki.num_buffers and cfg["split_k"] == ki.split_k + assert cfg["cluster_m"] == ki.cluster_m and cfg["cluster_n"] == ki.cluster_n + assert cfg["m_warp"] == ki.m_warp and cfg["n_warp"] == ki.n_warp @pytest.mark.parametrize( @@ -166,6 +177,28 @@ def test_ragged_m_split_k(M, split_k, monkeypatch): assert cos > 0.99, f"cosine={cos} too low (ragged M={M}, split_k={split_k})" +@pytest.mark.parametrize( + "m_warp,n_warp,tile_m,tile_n", + [(1, 2, 16, 32), (1, 4, 16, 64), (2, 2, 32, 32), (2, 4, 32, 64)], +) +def test_warp_configs(m_warp, n_warp, tile_m, tile_n, monkeypatch): + """m_warp/n_warp are tunable; m_warp=1 small-tile_m configs serve decode.""" + _inject_tuned_config( + monkeypatch, + _kernel_name(tile_m, tile_n, 256, 2, m_warp=m_warp, n_warp=n_warp), + ) + torch.manual_seed(0) + M, N, K = 1, 256, 512 + aq, bq, a_scale, b_scale = _quant(M, N, K) + ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) + out = aiter.gemm_a8w8_bpreshuffle( + aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 + ) + assert out.shape == (M, N) + _, cos = _metrics(out, ref) + assert cos > 0.99, f"cosine={cos} low (mw{m_warp} nw{n_warp} t{tile_m}x{tile_n})" + + def test_vendored_oob_path(monkeypatch): """Force the vendored OOB descriptor builder (the older-flydsl fallback) and verify ragged-M correctness, even when the installed flydsl has native OOB. From 8255a27ee30cdef830d0568290ca7816d6e9f4cd Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 7 Jun 2026 04:43:31 +0000 Subject: [PATCH 03/19] ptpc: set scale buffer num_records from runtime M/N to keep OOB clipping --- .../ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index 6bc90b7839..4f0332dd74 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -1086,8 +1086,11 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): fmtB=0, ) else: - # for compatibility, use the following no-scale wmma instead future - # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) + # PTPC-FP8 needs no per-K scaling. We emit the scaled f8f6f4 op + # with an identity E8M0 scale (0x7F = 2^0 = 1.0) for toolchain + # compatibility; it is numerically equivalent to the dedicated + # no-scale op. Future: switch to the equivalent no-scale wmma: + # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( T.vec(8, T.f32), b_frag, @@ -2399,8 +2402,16 @@ def finalize_acc_layout(accs_in): def epilogue_load_ptpc_scales(): # PTPC scales: sa[M] per-token (scalar per wm), sb[N] per-channel # (8 contiguous N cols per wn). Both fp32, constant along K. - sa_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - sb_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) + # The scale memrefs are dynamically shaped, so max_size=False would fall + # back to a max-sized descriptor and disable hardware OOB. Derive + # num_records from runtime M / compile-time N (fp32 = 4 bytes) so the + # partial last M-tile clips rows >= M (and cols >= N) to 0. + sa_rsrc = buffer_ops.create_buffer_resource( + arg_a_scale, num_records_bytes=m_idx * arith.index(4) + ) + sb_rsrc = buffer_ops.create_buffer_resource( + arg_b_scale, num_records_bytes=N * 4 + ) sa = [] for wm in range_constexpr(wmma_m_rep): row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 From bbbe43d18802a072afa304214bd0d904b6b49b50 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 7 Jun 2026 04:54:51 +0000 Subject: [PATCH 04/19] ptpc tune: cap warp tile at 128 to drop faulting tile_m=256/m_warp=1 --- .../flydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index 645d5e844a..f088a3a412 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -19,6 +19,7 @@ WMMA = 16 # WMMA M/N tile granularity WARP = 2 # default m_warp / n_warp for WmmaKernelInstance LDS_BYTES = 320 * 1024 # gfx1250 shared-memory capacity (== device limit) +_MAX_WARP_TILE = 128 # Mirror the ptpc fp8 LDS layout in gemm_fp8fp4_gfx1250.compile_fp8fp4_gemm so the # candidate filter matches the kernel's real allocation. fp8 packs 1 byte/elem @@ -31,8 +32,8 @@ _TILE_M = (16, 32, 64, 256) _TILE_N = (32, 64, 128, 256) _TILE_K = (128, 256) -_NUM_BUFFERS = (4,) -_SPLIT_K = (1, 2, 4) +_NUM_BUFFERS = (3, 4) +_SPLIT_K = (1, 2, 4, 8) # (m_warp, n_warp): wave-specialized ptpc requires m_warp*n_warp >= 2, and # block_threads = m_warp*n_warp*32 <= 1024 (m_warp*n_warp <= 32). m_warp=1 with a # small tile_m wins for decode (small M); m_warp=2 serves larger M; n_warp=4 helps @@ -70,10 +71,13 @@ def name(self) -> str: def _tile_valid(tm: int, tn: int, tk: int, mw: int, nw: int) -> bool: # Mirrors the kernel's warp_tile_m/n constraints: each must be a multiple of - # WMMA (16). block_threads = mw*nw*32 <= 1024 (mw*nw <= 32). + # WMMA (16) and at most _MAX_WARP_TILE (per-wave VGPR/accumulator budget; see + # the _MAX_WARP_TILE note). block_threads = mw*nw*32 <= 1024 (mw*nw <= 32). return ( tm % (mw * WMMA) == 0 and tn % (nw * WMMA) == 0 + and tm // mw <= _MAX_WARP_TILE + and tn // nw <= _MAX_WARP_TILE and tk % 128 == 0 and mw * nw <= 32 ) From 848122422cf0791a5126fc8282b6fb5d6cffbca4 Mon Sep 17 00:00:00 2001 From: aoli Date: Sun, 7 Jun 2026 05:05:15 +0000 Subject: [PATCH 05/19] [gfx1250] flydsl ptpc bpreshuffle gemm: arch dispatch and review fixes --- aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | 36 ++++++++++++------- ...lydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 11 ------ .../ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | 5 ++- aiter/ops/flydsl/kernels/tdm_oob.py | 1 - aiter/ops/gemm_op_a8w8.py | 3 +- 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index debf0b188d..d0d1022857 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -28,7 +28,6 @@ _SUPPORTED_NUM_BUFFERS = (2, 3, 4) _OUT_DTYPE_NAME = {torch.bfloat16: "bf16", torch.float16: "f16"} - def _lazy_import(): global _compile_ptpc_gemm, _run_compiled, _fx if _compile_ptpc_gemm is not None: @@ -117,24 +116,36 @@ def run_preshuffle_gemm_a8_gfx1250( accumulate_fp32 = split_k > 1 kernel_out_dtype = "f32" if accumulate_fp32 else out_dtype - # Pipeline depth needs >= 1 K tile per buffer (per split-k chunk). - num_k_tiles = (K // split_k) // tile_k - nb = max(2, min(int(num_buffers), num_k_tiles)) + # Validate (tuned names always pass); fail loudly rather than silently clamp. + nb = int(num_buffers) if nb not in _SUPPORTED_NUM_BUFFERS: - nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb) + raise RuntimeError( + f"[FlyDSL gfx1250] num_buffers must be one of {_SUPPORTED_NUM_BUFFERS}, " + f"got {nb}" + ) + if K % (split_k * tile_k) != 0: + raise RuntimeError( + f"[FlyDSL gfx1250] K={K} must be divisible by split_k*tile_k=" + f"{split_k}*{tile_k}={split_k * tile_k}" + ) + # Each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. + num_k_tiles = (K // split_k) // tile_k + if num_k_tiles < nb: + raise RuntimeError( + f"[FlyDSL gfx1250] {nb}-buffer pipeline needs >= {nb} K-tiles per " + f"split-k chunk, got {num_k_tiles} (K={K}, split_k={split_k}, tile_k={tile_k})" + ) sa = _as_1d_fp32(x_scale, M, "x_scale") sb = _as_1d_fp32(w_scale, N, "w_scale") - # Ragged M needs no host padding: the kernel clips A/A-scale loads and the C - # store to the runtime M via hardware out-of-bounds, so A and the scales pass - # through unchanged. Only N/K must divide the tile (checked above). + # Ragged M needs no host padding: the kernel clips A/scale loads and the C + # store to runtime M via hardware OOB. a_dev = XQ.contiguous() b_dev = WQ.contiguous() - # Compile-time M is unused for codegen (the kernel reads the runtime i32_m for - # all bounds/grid), so compile with M=0: the kernel is cached (lru_cache) per - # (N, K, config) and reused across every M instead of recompiling per M. + # M is unused at compile time (runtime i32_m drives everything), so compile with + # M=0 to cache one kernel per (N, K, config) and reuse it across all M. exe = _compile_ptpc_gemm( M=0, N=N, @@ -154,8 +165,7 @@ def run_preshuffle_gemm_a8_gfx1250( ) if accumulate_fp32: - # fp32 atomic-accumulation scratch: zeroed because the split-k atomic add - # accumulates into it (per-lane predicated on row < M), cast into Out below. + # fp32 split-k atomic-accumulation scratch (zeroed; cast into Out below). out_buf = torch.zeros((M, N), dtype=torch.float32, device=Out.device) else: out_buf = Out.contiguous() diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index f088a3a412..2aaec64fa4 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -113,10 +113,6 @@ def kernel_instance_estimated_lds_bytes(ki: WmmaKernelInstance) -> int: return arena_bytes -def max_lds_bytes_for_tune() -> int: - return LDS_BYTES - - def _build_kernels_list(): kl = {} idx = 0 @@ -135,13 +131,6 @@ def _build_kernels_list(): kernels_list: dict[int, WmmaKernelInstance] = _build_kernels_list() -default_kernels_dict = { - (-1): WmmaKernelInstance(128, 128, 128, 2), - (-2): WmmaKernelInstance(32, 64, 128, 2), - (-3): WmmaKernelInstance(64, 128, 128, 2), - (-4): WmmaKernelInstance(128, 256, 128, 2), -} - def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: """N must divide tile_n (N is never padded); K must divide split_k*tile_k, and diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index 4f0332dd74..f1efd210b3 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -75,13 +75,16 @@ def _make_tdm_desc(*, early_timeout=False, oob_outer_bound=None, **kwargs): against the older flydsl this build pins. """ if _TDM_HAS_OOB: + if _TDM_HAS_EARLY_TIMEOUT: + kwargs["early_timeout"] = early_timeout return tdm_ops.make_tensor_descriptor_2d( - early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs + oob_outer_bound=oob_outer_bound, **kwargs ) if oob_outer_bound is None: if _TDM_HAS_EARLY_TIMEOUT: kwargs["early_timeout"] = early_timeout return tdm_ops.make_tensor_descriptor_2d(**kwargs) + # Vendored builder always supports early_timeout (pass it explicitly). from .tdm_oob import make_tensor_descriptor_2d as _make_tensor_descriptor_2d_oob return _make_tensor_descriptor_2d_oob( diff --git a/aiter/ops/flydsl/kernels/tdm_oob.py b/aiter/ops/flydsl/kernels/tdm_oob.py index 0bf1aac7b6..5201114d9f 100644 --- a/aiter/ops/flydsl/kernels/tdm_oob.py +++ b/aiter/ops/flydsl/kernels/tdm_oob.py @@ -79,7 +79,6 @@ def make_tensor_descriptor_2d( """ from flydsl._mlir.dialects import fly as _fly_d - outer_size, inner_size = tensor_shape outer_stride, inner_stride = strides outer_tile, inner_tile = tile_shape outer_off, inner_off = global_offset diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 3f9627dcad..9525a64df0 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -128,7 +128,8 @@ def gemm_a8w8_bpreshuffle_flydsl( config: dict, ) -> Tensor: kernel_name = str(config.get("kernelName", "")) - if kernel_name.startswith("flydsl_bpreshuffle_wmma_"): + # gfx1250 runs the WMMA ptpc backend; other archs use the MFMA preshuffle path. + if get_gfx() == "gfx1250": from .flydsl.bpreshuffle_gemm_gfx1250 import run_gemm_a8w8_bpreshuffle_gfx1250 return run_gemm_a8w8_bpreshuffle_gfx1250( From f678a2dad2b99acb905e2ca537f2c7ddb3c2974c Mon Sep 17 00:00:00 2001 From: aoli Date: Sun, 7 Jun 2026 05:10:01 +0000 Subject: [PATCH 06/19] fix ci format --- aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index d0d1022857..d86774ce50 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -28,6 +28,7 @@ _SUPPORTED_NUM_BUFFERS = (2, 3, 4) _OUT_DTYPE_NAME = {torch.bfloat16: "bf16", torch.float16: "f16"} + def _lazy_import(): global _compile_ptpc_gemm, _run_compiled, _fx if _compile_ptpc_gemm is not None: From b4dfd49bc9afa17f44e9591797abb106639f6b45 Mon Sep 17 00:00:00 2001 From: aoli Date: Sun, 7 Jun 2026 06:18:21 +0000 Subject: [PATCH 07/19] add gfx1250 gemm shape0 m=1/2/4 tuned --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 9b9f652a22..1267ca8258 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,3 +549,6 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 +gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,257,0,8.8793,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.24,0.0066 +gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,489,0,8.5611,flydsl_bpreshuffle_wmma_t32x32x128_mw2_nw2_nb4_sk8_cm1_cn1,7.07,1770.99,0.0078 +gfx1250,256,4,2112,7168,torch.float8_e4m3fn,flydsl,257,0,9.014,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,13.44,1684.53,0.0094 From 8c7891e8df87d3dc20ea9e9bbd1031f0c36692d0 Mon Sep 17 00:00:00 2001 From: aoli Date: Sun, 7 Jun 2026 07:24:23 +0000 Subject: [PATCH 08/19] change gfx1250 tuned knob --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 5 ++--- .../flydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 16 +++++----------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 1267ca8258..7c944b6b1a 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,6 +549,5 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 -gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,257,0,8.8793,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.24,0.0066 -gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,489,0,8.5611,flydsl_bpreshuffle_wmma_t32x32x128_mw2_nw2_nb4_sk8_cm1_cn1,7.07,1770.99,0.0078 -gfx1250,256,4,2112,7168,torch.float8_e4m3fn,flydsl,257,0,9.014,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,13.44,1684.53,0.0094 +gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066 +gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,381,0,8.2149,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk4_cm1_cn1,7.37,1845.62,0.0078 diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index 2aaec64fa4..45946f6886 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -6,7 +6,8 @@ gfx1250 has no MFMA preshuffle kernel; it runs the vendored gemm_fp8fp4_gfx1250 WMMA kernel (ptpc) via ``bpreshuffle_gemm_gfx1250``. This is the WMMA counterpart of ``flydsl_gemm_a8w8_bpreshuffle_common`` (which serves gfx942/gfx950 MFMA), with -its own perf knobs — num_buffers, split_k, cluster — and kernelName format. +its own perf knobs (num_buffers and split_k) and kernelName format. Cluster is +kept in the candidate schema/name but fixed to (1, 1). """ from __future__ import annotations @@ -34,12 +35,8 @@ _TILE_K = (128, 256) _NUM_BUFFERS = (3, 4) _SPLIT_K = (1, 2, 4, 8) -# (m_warp, n_warp): wave-specialized ptpc requires m_warp*n_warp >= 2, and -# block_threads = m_warp*n_warp*32 <= 1024 (m_warp*n_warp <= 32). m_warp=1 with a -# small tile_m wins for decode (small M); m_warp=2 serves larger M; n_warp=4 helps -# very wide N. Add (1, 4) / (4, 2) here if a shape benefits. -_WARP_COMBOS = ((1, 2), (2, 2), (2, 4)) -_CLUSTER = ((1, 1),) # cluster_m * cluster_n <= 16 +_WARP_COMBOS = ((1, 2), (2, 2), (1, 4)) +_CLUSTER = ((1, 1),) # Keep cluster in candidates/name. @dataclass @@ -137,16 +134,13 @@ def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. M may be ragged — the kernel clips loads/stores to M via hardware out-of-bounds, so no M divisibility is required (a cluster just rounds the M-grid up and OOB-clips the - extra tiles). A cluster still needs N cluster-tile-divisible and only pays off - for M, N >= 4096. (LDS is bounded at build time, so it is not re-checked here.) + extra tiles). A cluster still needs N cluster-tile-divisible. """ if N % ki.tile_n != 0 or K % (ki.split_k * ki.tile_k) != 0: return False if (K // ki.split_k) // ki.tile_k < ki.num_buffers: return False if ki.cluster_m > 1 or ki.cluster_n > 1: - if M < 4096 or N < 4096: - return False if N % (ki.cluster_n * ki.tile_n) != 0: return False return True From b6a1ff3dc7c53f2ddbb51086dfe8d31d7fa6b2b1 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 7 Jun 2026 12:05:56 +0000 Subject: [PATCH 09/19] [gfx1250] flydsl ptpc gemm: support strided A/C (runtime lda/ldc), pass 2-D tensors --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 29 +++ aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | 61 +++--- ...lydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 86 +++++---- .../ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | 65 ++++--- aiter/ops/flydsl/kernels/tdm_oob.py | 37 +++- aiter/ops/flydsl/utils.py | 2 + .../test_gemm_a8w8_bpreshuffle_gfx1250.py | 182 +++++++++++++----- 7 files changed, 310 insertions(+), 152 deletions(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 7c944b6b1a..1f9e989ca4 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -551,3 +551,32 @@ gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpr gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066 gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,381,0,8.2149,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk4_cm1_cn1,7.37,1845.62,0.0078 +gfx1250,256,4,2112,7168,torch.float8_e4m3fn,flydsl,294,0,28.6242,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,4.23,530.47,0.0099 +gfx1250,256,8,2112,7168,torch.float8_e4m3fn,flydsl,303,0,27.762,flydsl_bpreshuffle_wmma_t16x64x128_mw1_nw4_nb4_sk1_cm1_cn1,8.72,548.59,0.0102 +gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,25,0,30.8522,flydsl_bpreshuffle_wmma_t32x32x256_mw2_nw2_nb3_sk1_cm1_cn1,15.7,496.6,0.011 +gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,286,0,28.9229,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,33.5,536.02,0.0105 +gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,269,0,31.7564,flydsl_bpreshuffle_wmma_t32x32x256_mw1_nw2_nb4_sk1_cm1_cn1,61.02,499.68,0.0098 +gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2612.66,1617.42,0.0237 +gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237 +gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 +gfx1250,256,6144,7168,16384,torch.float8_e4m3fn,flydsl,302,0,475.4687,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3035.13,643.96,0.0237 +gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237 +gfx1250,256,10240,7168,16384,torch.float8_e4m3fn,flydsl,302,0,806.3905,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2982.65,535.74,0.0237 +gfx1250,256,20480,7168,16384,torch.float8_e4m3fn,flydsl,302,0,1652.4551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2911.04,451.8,0.0237 +gfx1250,256,65536,7168,16384,torch.float8_e4m3fn,flydsl,302,0,5987.8537,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2570.73,355.84,0.0237 +gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0 +gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0 +gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 +gfx1250,256,6144,24576,1536,torch.float8_e4m3fn,flydsl,302,0,197.9044,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2343.84,1764.37,0.0 +gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0 +gfx1250,256,10240,24576,1536,torch.float8_e4m3fn,flydsl,302,0,316.7838,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2440.45,1757.65,0.0 +gfx1250,256,20480,24576,1536,torch.float8_e4m3fn,flydsl,302,0,607.1774,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2546.52,1771.87,0.0 +gfx1250,256,65536,24576,1536,torch.float8_e4m3fn,flydsl,302,0,2046.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2418.19,1641.99,0.0 +gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,302,0,50.7041,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,677.65,1664.76,0.0 +gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,302,0,54.9014,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1251.69,2769.39,0.0 +gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,302,0,88.9453,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1545.21,3230.19,0.0 +gfx1250,256,6144,32768,512,torch.float8_e4m3fn,flydsl,302,0,119.3873,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1726.8,3539.54,0.0 +gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,302,0,150.5706,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1825.57,3704.86,0.0 +gfx1250,256,10240,32768,512,torch.float8_e4m3fn,flydsl,302,0,187.2761,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1834.71,3701.0,0.0 +gfx1250,256,20480,32768,512,torch.float8_e4m3fn,flydsl,302,0,393.5255,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1746.25,3479.93,0.0 +gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096 diff --git a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py index d86774ce50..d6cc0ae675 100644 --- a/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py +++ b/aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py @@ -3,13 +3,15 @@ """gfx1250 (WMMA) backend for the FlyDSL a8w8 bpreshuffle GEMM. -aiter.gemm_a8w8_bpreshuffle routes here when its tuned kernelName starts with -``flydsl_bpreshuffle_wmma_`` (gfx1250 has no MFMA preshuffle kernel). Runs the -vendored gemm_fp8fp4_gfx1250 WMMA kernel in ptpc scale mode: C = (A*sa) @ (B*sb)^T -with fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. N/K must -divide the tile; M may be non-tile-aligned (ragged) with no host padding — the -kernel clips A/A-scale loads and the C store to the runtime M via hardware -out-of-bounds handling (split-k predicates the atomic add per-lane on row < M). +aiter.gemm_a8w8_bpreshuffle's FlyDSL path runs here on gfx1250 (no MFMA preshuffle +kernel); the tuned kernelName (prefix ``flydsl_bpreshuffle_wmma_``) encodes the tile +config. Computes C = (A*sa) @ (B*sb)^T via the vendored gemm_fp8fp4_gfx1250 WMMA +kernel, fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. + +N/K must divide the tile; M may be ragged (no host padding) — the kernel clips +loads/stores to the runtime M via hardware OOB, predicating split-k's atomic add +per-lane on row < M. A/C may be strided (lda/ldc passed at runtime, no copy) when +the inner dim is unit-stride; B is preshuffled into its own contiguous buffer. """ from __future__ import annotations @@ -27,6 +29,7 @@ _WMMA_K = 128 _SUPPORTED_NUM_BUFFERS = (2, 3, 4) _OUT_DTYPE_NAME = {torch.bfloat16: "bf16", torch.float16: "f16"} +_MAX_SPLIT_K = 4 def _lazy_import(): @@ -55,10 +58,6 @@ def _as_1d_fp32(scale: Tensor, length: int, name: str) -> Tensor: return flat.contiguous() -def _to_uint8(t: Tensor) -> Tensor: - return t.contiguous().view(torch.uint8).view(-1) - - def run_preshuffle_gemm_a8_gfx1250( XQ: Tensor, WQ: Tensor, @@ -114,8 +113,11 @@ def run_preshuffle_gemm_a8_gfx1250( cluster_m = max(1, int(cluster_m)) cluster_n = max(1, int(cluster_n)) - accumulate_fp32 = split_k > 1 - kernel_out_dtype = "f32" if accumulate_fp32 else out_dtype + if split_k > _MAX_SPLIT_K: + raise RuntimeError( + f"[FlyDSL gfx1250] split_k={split_k} exceeds the bf16/f16 atomic-add " + f"precision cap of {_MAX_SPLIT_K}" + ) # Validate (tuned names always pass); fail loudly rather than silently clamp. nb = int(num_buffers) @@ -140,15 +142,7 @@ def run_preshuffle_gemm_a8_gfx1250( sa = _as_1d_fp32(x_scale, M, "x_scale") sb = _as_1d_fp32(w_scale, N, "w_scale") - # Ragged M needs no host padding: the kernel clips A/scale loads and the C - # store to runtime M via hardware OOB. - a_dev = XQ.contiguous() - b_dev = WQ.contiguous() - - # M is unused at compile time (runtime i32_m drives everything), so compile with - # M=0 to cache one kernel per (N, K, config) and reuse it across all M. exe = _compile_ptpc_gemm( - M=0, N=N, K=K, data_format="fp8", @@ -161,32 +155,29 @@ def run_preshuffle_gemm_a8_gfx1250( waves_per_eu=(None if waves_per_eu <= 0 else waves_per_eu), cluster_m=cluster_m, cluster_n=cluster_n, - out_dtype=kernel_out_dtype, + out_dtype=out_dtype, split_k=split_k, ) - if accumulate_fp32: - # fp32 split-k atomic-accumulation scratch (zeroed; cast into Out below). - out_buf = torch.zeros((M, N), dtype=torch.float32, device=Out.device) - else: - out_buf = Out.contiguous() + lda = XQ.stride(0) + ldc = Out.stride(0) + if split_k > 1: + Out.zero_() # split-k atomic-accumulates into Out - stream = _fx.Stream(torch.cuda.current_stream(device=a_dev.device)) - # sa/sb are already contiguous 1-D fp32 (see _as_1d_fp32). + stream = _fx.Stream(torch.cuda.current_stream(device=XQ.device)) _run_compiled( exe, - out_buf.view(-1), - _to_uint8(a_dev), - _to_uint8(b_dev), + Out, + XQ.view(torch.uint8), + WQ.view(torch.uint8), sa.view(-1), sb.view(-1), M, N, + lda, + ldc, stream, ) - - if out_buf.data_ptr() != Out.data_ptr(): - Out.copy_(out_buf) return Out diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index 45946f6886..ea2f7e5202 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -13,30 +13,42 @@ from __future__ import annotations from dataclasses import dataclass -from itertools import product from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name +from aiter.ops.flydsl.utils import get_shared_memory_per_block WMMA = 16 # WMMA M/N tile granularity WARP = 2 # default m_warp / n_warp for WmmaKernelInstance -LDS_BYTES = 320 * 1024 # gfx1250 shared-memory capacity (== device limit) +LDS_BYTES = get_shared_memory_per_block(fallback_gfx="gfx1250") _MAX_WARP_TILE = 128 -# Mirror the ptpc fp8 LDS layout in gemm_fp8fp4_gfx1250.compile_fp8fp4_gemm so the -# candidate filter matches the kernel's real allocation. fp8 packs 1 byte/elem -# (packed_tile_k == tile_k); ptpc keeps no scale pool in LDS. Keep in sync with -# that file's LDS_PAD_A_BYTES / LDS_PAD_D_BYTES / elem_bytes_d. +# Mirror the ptpc fp8 LDS layout _LDS_PAD_A_BYTES = 16 _LDS_PAD_D_BYTES = 16 _ELEM_BYTES_D = 2 # bf16 / f16 output -_TILE_M = (16, 32, 64, 256) -_TILE_N = (32, 64, 128, 256) -_TILE_K = (128, 256) -_NUM_BUFFERS = (3, 4) -_SPLIT_K = (1, 2, 4, 8) -_WARP_COMBOS = ((1, 2), (2, 2), (1, 4)) -_CLUSTER = ((1, 1),) # Keep cluster in candidates/name. +# Columns: (tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers), grouped by M regime +# fmt: off +_CURATED_INSTANCES = ( + # small M (decode / token-gen): thin tile_m, wide tile_n, deep tile_k + ( 16, 64, 256, 1, 2, 4), ( 16, 64, 512, 1, 2, 4), ( 16, 96, 256, 1, 2, 4), + ( 16, 128, 256, 1, 2, 4), ( 16, 128, 512, 1, 2, 4), ( 16, 192, 256, 1, 2, 4), + ( 16, 256, 256, 1, 4, 4), ( 16, 512, 128, 1, 4, 4), + # M=32 + ( 32, 64, 256, 2, 2, 4), ( 32, 64, 512, 2, 2, 4), ( 32, 128, 256, 2, 2, 4), + ( 32, 192, 256, 2, 2, 4), ( 32, 256, 256, 2, 4, 4), + # M=64 + ( 64, 64, 256, 2, 2, 4), ( 64, 128, 128, 2, 2, 4), ( 64, 192, 128, 2, 2, 4), + ( 64, 256, 128, 2, 4, 4), ( 64, 512, 128, 1, 4, 3), + # M=128 + (128, 128, 128, 2, 2, 4), (128, 192, 128, 2, 2, 4), (128, 256, 128, 2, 4, 4), + (128, 512, 128, 2, 4, 3), + # large M (compute bound): big square tiles, shallow tile_k + (256, 64, 128, 2, 1, 4), (256, 128, 128, 2, 2, 4), (256, 192, 128, 2, 2, 4), + (256, 256, 128, 2, 2, 4), (256, 256, 128, 4, 4, 3), (256, 512, 128, 2, 4, 3), +) +# fmt: on +_SPLIT_K = (1, 2, 4) @dataclass @@ -67,9 +79,8 @@ def name(self) -> str: def _tile_valid(tm: int, tn: int, tk: int, mw: int, nw: int) -> bool: - # Mirrors the kernel's warp_tile_m/n constraints: each must be a multiple of - # WMMA (16) and at most _MAX_WARP_TILE (per-wave VGPR/accumulator budget; see - # the _MAX_WARP_TILE note). block_threads = mw*nw*32 <= 1024 (mw*nw <= 32). + # Each warp tile must be a multiple of WMMA (16) and <= _MAX_WARP_TILE (VGPR + # budget); tk a multiple of 128; block_threads = mw*nw*32 <= 1024. return ( tm % (mw * WMMA) == 0 and tn % (nw * WMMA) == 0 @@ -85,15 +96,12 @@ def _align_up(value: int, align: int) -> int: def kernel_instance_estimated_lds_bytes(ki: WmmaKernelInstance) -> int: - """LDS bytes the ptpc fp8 WMMA kernel actually allocates for ``ki``. - - Replicates the per-stage arena layout: each stage holds the A data pool - (rows padded by LDS_PAD_A_BYTES) followed by the 16-aligned B data pool (no - ptpc scale pool), the stage is 128- then 1024-aligned, and the arena is that - pitch times num_buffers. The split_k==1 epilogue also needs a TDM-store D - buffer, which can exceed the arena for small tiles, so take the max. The - estimate must be exact/conservative: an under-estimate would let an - overflowing tile through the candidate filter and fault the GPU at launch. + """LDS bytes the ptpc fp8 WMMA kernel allocates for ``ki`` (must not under-estimate: + an overflowing tile would pass the filter and fault at launch). + + Per-stage arena: A pool (rows padded by LDS_PAD_A_BYTES) + 16-aligned B pool, the + stage 128- then 1024-aligned, times num_buffers. The split_k==1 epilogue also needs + a TDM-store D buffer that can exceed the arena for small tiles, so take the max. """ lds_a_data = ki.tile_m * (ki.tile_k + _LDS_PAD_A_BYTES) lds_b_data = ki.tile_n * ki.tile_k @@ -113,16 +121,16 @@ def kernel_instance_estimated_lds_bytes(ki: WmmaKernelInstance) -> int: def _build_kernels_list(): kl = {} idx = 0 - for nb, sk, (cm, cn), (mw, nw), tm, tn, tk in product( - _NUM_BUFFERS, _SPLIT_K, _CLUSTER, _WARP_COMBOS, _TILE_M, _TILE_N, _TILE_K - ): - if not _tile_valid(tm, tn, tk, mw, nw): - continue - ki = WmmaKernelInstance(tm, tn, tk, nb, sk, cm, cn, mw, nw) - if kernel_instance_estimated_lds_bytes(ki) > LDS_BYTES: - continue - kl[idx] = ki - idx += 1 + for tm, tn, tk, mw, nw, nb in _CURATED_INSTANCES: + assert _tile_valid( # an invalid curated entry is a typo -- fail loudly + tm, tn, tk, mw, nw + ), f"invalid curated instance: tile=({tm},{tn},{tk}) warp=({mw},{nw})" + for sk in _SPLIT_K: + ki = WmmaKernelInstance(tm, tn, tk, nb, sk, 1, 1, mw, nw) + if kernel_instance_estimated_lds_bytes(ki) > LDS_BYTES: + continue + kl[idx] = ki + idx += 1 return kl @@ -130,11 +138,9 @@ def _build_kernels_list(): def kernel_fits_shape(ki: WmmaKernelInstance, M: int, N: int, K: int) -> bool: - """N must divide tile_n (N is never padded); K must divide split_k*tile_k, and - each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline. M may - be ragged — the kernel clips loads/stores to M via hardware out-of-bounds, so no - M divisibility is required (a cluster just rounds the M-grid up and OOB-clips the - extra tiles). A cluster still needs N cluster-tile-divisible. + """N must divide tile_n; K must divide split_k*tile_k with >= num_buffers K-tiles + per chunk. M may be ragged (kernel OOB-clips, no divisibility needed). A cluster + also needs N divisible by cluster_n*tile_n. """ if N % ki.tile_n != 0 or K % (ki.split_k * ki.tile_k) != 0: return False diff --git a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py index f1efd210b3..c44424b81e 100644 --- a/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py +++ b/aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py @@ -66,30 +66,23 @@ def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): def _make_tdm_desc(*, early_timeout=False, oob_outer_bound=None, **kwargs): - """Build a TDM descriptor across flydsl versions. + """Build a 2D TDM descriptor, transparently across flydsl versions.""" + strides = kwargs.get("strides") + runtime_stride = strides is not None and not isinstance(strides[0], int) + needs_oob = oob_outer_bound is not None - ``early_timeout`` / ``oob_outer_bound`` are passed to the native builder - only when it supports them. When ``oob_outer_bound`` is needed (ragged M) - but the installed flydsl lacks it, the vendored builder in ``tdm_oob`` is - used — it carries the OOB logic so the non-tile-aligned-M GEMM path works - against the older flydsl this build pins. - """ - if _TDM_HAS_OOB: - if _TDM_HAS_EARLY_TIMEOUT: - kwargs["early_timeout"] = early_timeout - return tdm_ops.make_tensor_descriptor_2d( - oob_outer_bound=oob_outer_bound, **kwargs + if runtime_stride or (needs_oob and not _TDM_HAS_OOB): + from .tdm_oob import make_tensor_descriptor_2d as _vendored_make_desc + + return _vendored_make_desc( + early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs ) - if oob_outer_bound is None: - if _TDM_HAS_EARLY_TIMEOUT: - kwargs["early_timeout"] = early_timeout - return tdm_ops.make_tensor_descriptor_2d(**kwargs) - # Vendored builder always supports early_timeout (pass it explicitly). - from .tdm_oob import make_tensor_descriptor_2d as _make_tensor_descriptor_2d_oob - - return _make_tensor_descriptor_2d_oob( - early_timeout=early_timeout, oob_outer_bound=oob_outer_bound, **kwargs - ) + + if _TDM_HAS_OOB: + kwargs["oob_outer_bound"] = oob_outer_bound + if _TDM_HAS_EARLY_TIMEOUT: + kwargs["early_timeout"] = early_timeout + return tdm_ops.make_tensor_descriptor_2d(**kwargs) # Common constants @@ -109,7 +102,6 @@ def compile_fp8fp4_gemm( *, data_format: str = "fp4", scale_mode: str = "mxscale", - M: int = 0, N: int = 0, K: int, tile_m: int = 128, @@ -148,7 +140,9 @@ def compile_fp8fp4_gemm( ptpc: scale_A [M], scale_B [N] fp32 Returns a JitFunction: - launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, stream) + launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, lda, ldc, stream) + where lda / ldc are the runtime leading-dim strides (in elements) of A / C. + Pass lda == K and ldc == N for dense (contiguous) tensors. """ if data_format not in ("fp4", "fp8", "a8w4"): raise ValueError( @@ -630,6 +624,8 @@ def kernel_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, ): # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) rocdl.disable_xdl_arb_stall() @@ -711,7 +707,14 @@ def _bvs_prefetch(k_base): return a, b m_idx = fx.Index(i32_m) - n_stride = arith.index(N) + # Leading-dim strides arrive at runtime (strided A / C); the dense path + # passes lda == K and ldc == N, giving byte-identical addressing. A's + # stride is in packed-A elements (== lda for fp8 where PACK_FACTOR_A == 1). + if const_expr(PACK_FACTOR_A == 1): + lda_packed = fx.Index(i32_lda) + else: + lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) + n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) c_global_ptr_type = ir.Type.parse("!llvm.ptr<1>") @@ -729,7 +732,7 @@ def make_desc_a(memref, k_base): lds_memref=memref, global_offset=(blk_m, k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(tile_m, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -767,7 +770,7 @@ def make_desc_a_half(memref, k_base, m_half: int): lds_memref=memref, global_offset=(blk_m + arith.index(row_start), k_packed_off), tensor_shape=(tile_m, packed_tile_k_a), - strides=(K_packed_a, 1), + strides=(lda_packed, 1), tile_shape=(ab_split_a_rows, packed_tile_k_a), elem_bytes=1, pad_interval=packed_tile_k_a, @@ -2598,7 +2601,7 @@ def _l2_prefetch(k_base): lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), tensor_shape=(warp_tile_m, warp_tile_n), - strides=(N, 1), + strides=(n_stride, 1), tile_shape=(warp_tile_m, warp_tile_n), elem_bytes=elem_bytes_d, pad_interval=warp_tile_n, @@ -3233,6 +3236,8 @@ def launch_mxscale_gemm( arg_b_scale: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, + i32_lda: fx.Int32, + i32_ldc: fx.Int32, stream: fx.Stream, ): _ = cache_tag @@ -3259,6 +3264,8 @@ def launch_mxscale_gemm( arg_b_scale, i32_m, i32_n, + i32_lda, + i32_ldc, value_attrs={ "rocdl.waves_per_eu": effective_waves_per_eu, "rocdl.cluster_dims": ( @@ -3299,7 +3306,6 @@ def compile_a8w4_gemm(**kw): def compile_ptpc_gemm( *, - M: int = 0, N: int = 0, K: int, data_format: str = "fp8", @@ -3338,7 +3344,6 @@ def compile_ptpc_gemm( fp8_schedule="auto", scale_load_path="tdm", use_tdm_store=(split_k == 1), - M=M, N=N, K=K, tile_m=tile_m, diff --git a/aiter/ops/flydsl/kernels/tdm_oob.py b/aiter/ops/flydsl/kernels/tdm_oob.py index 5201114d9f..4e1e3ad867 100644 --- a/aiter/ops/flydsl/kernels/tdm_oob.py +++ b/aiter/ops/flydsl/kernels/tdm_oob.py @@ -83,6 +83,35 @@ def make_tensor_descriptor_2d( outer_tile, inner_tile = tile_shape outer_off, inner_off = global_offset + # The outer (leading-dim) stride may be a compile-time int or a runtime + # i32/index ir.Value (strided A/C, e.g. a row-slice whose row pitch exceeds + # the logical inner extent). Normalise to an index value for address math and + # an i32 value for the descriptor's stride field (sgpr5). + if isinstance(outer_stride, int): + outer_stride_idx = arith.index(outer_stride) + outer_stride_is_runtime = False + else: + os_val = ( + outer_stride.ir_value() + if hasattr(outer_stride, "ir_value") + else outer_stride + ) + if not isinstance(os_val, ir.Value): + raise TypeError( + f"outer stride must be int or i32/index ir.Value, " + f"got {type(outer_stride).__name__}" + ) + if isinstance(os_val.type, ir.IndexType): + # Wrap raw ir.Value so it supports the _ArithValue ops below (*, cast). + outer_stride_idx = _ArithValue(os_val) + elif isinstance(os_val.type, ir.IntegerType) and os_val.type.width == 32: + outer_stride_idx = arith.index_cast(T.index, os_val) + else: + raise TypeError( + f"outer stride ir.Value must be index or i32, got {os_val.type}" + ) + outer_stride_is_runtime = True + # -- Warp distribution -- warps_per_dim, block_per_warp = compute_warp_distribution( [outer_tile, inner_tile], @@ -112,7 +141,7 @@ def make_tensor_descriptor_2d( a_raw = global_ptr.__extract_to_ir_values__()[0] glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) - glb_elem_off = (outer_off + warp_off_outer) * arith.index(outer_stride) + ( + glb_elem_off = (outer_off + warp_off_outer) * outer_stride_idx + ( inner_off + warp_off_inner ) * arith.index(inner_stride) glb_byte_off = glb_elem_off * arith.index(elem_bytes) @@ -266,7 +295,11 @@ def make_tensor_descriptor_2d( g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim - g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + if outer_stride_is_runtime: + # Runtime leading-dim stride: truncate the index to i32 (strides < 2^31). + g1_s5 = arith.index_cast(T.i32, outer_stride_idx) + else: + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) # sgpr6-7: for 2D, no higher-dim strides g1_s6 = arith.constant(0, type=T.i32) diff --git a/aiter/ops/flydsl/utils.py b/aiter/ops/flydsl/utils.py index b3a1a49f55..ba6a60b00e 100644 --- a/aiter/ops/flydsl/utils.py +++ b/aiter/ops/flydsl/utils.py @@ -17,6 +17,8 @@ def addressable_lds_bytes_for_gfx(gfx: str) -> int: return _FALLBACK_MAX_LDS_BYTES if g.startswith("gfx950"): return 163840 + if g.startswith("gfx1250"): + return 327680 if g.startswith("gfx7") or g.startswith("gfx8"): return 32768 return 65536 diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py index ab7e3f6d1e..a0366634a5 100644 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py @@ -3,11 +3,11 @@ """Correctness tests for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. -aiter.gemm_a8w8_bpreshuffle routes its FlyDSL path to the WMMA backend -(bpreshuffle_gemm_gfx1250) by the kernelName prefix ``flydsl_bpreshuffle_wmma_``. -Semantics are the ordinary a8w8 per-token (x_scale[M]) / per-channel (w_scale[N]) -fp8 GEMM, so inputs are quantized exactly like the standard a8w8 path. Skipped off -gfx1250. +On gfx1250, aiter.gemm_a8w8_bpreshuffle dispatches its FlyDSL path to the WMMA +backend (bpreshuffle_gemm_gfx1250); its tuned kernelNames carry the prefix +``flydsl_bpreshuffle_wmma_``. Semantics are the ordinary a8w8 per-token +(x_scale[M]) / per-channel (w_scale[N]) fp8 GEMM, so inputs are quantized exactly +like the standard a8w8 path. Skipped off gfx1250. """ import pytest @@ -15,6 +15,7 @@ import aiter from aiter.utility import dtypes +from aiter.test_common import checkAllclose from aiter.ops.shuffle import shuffle_weight from aiter.jit.utils.chip_info import get_gfx_runtime as get_gfx from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name @@ -49,13 +50,6 @@ def _kernel_name( ) -def _metrics(out, ref): - out_f, ref_f = out.float(), ref.float() - rel = (out_f - ref_f).abs().sum() / ref_f.abs().sum().clamp_min(1e-6) - cos = torch.nn.functional.cosine_similarity(out_f.flatten(), ref_f.flatten(), dim=0) - return rel.item(), cos.item() - - def _quant(M, N, K): a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * 2.0 b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * 2.0 @@ -77,6 +71,15 @@ def _inject_tuned_config(monkeypatch, name): monkeypatch.setattr(gmod, "get_GEMM_config_with_quant_type", lambda *a, **k: config) +def _assert_close(ref, out, *, split_k=1, msg=""): + rtol = atol = 2e-2 if split_k > 1 else 1e-2 + bound = 0.10 if split_k > 1 else 0.05 + err = checkAllclose( + ref, out, rtol=rtol, atol=atol, msg=msg, catastrophic_check=True + ) + assert err <= bound, f"{msg}: {err:.1%} of elements exceed tol (bound {bound:.0%})" + + def test_kernel_name_roundtrips(): """Every catalogue kernelName must decode back to its config.""" from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import parse_wmma_kernel_name @@ -101,11 +104,24 @@ def test_kernel_name_roundtrips(): @pytest.mark.parametrize( "M,N,K", [ + # decode (small M; ragged -> kernel OOB-clips, no host pad) + (1, 4096, 4096), # M=1 decode + (2, 1280, 8192), # M=2 decode, qkv_proj-like + (4, 8192, 1024), # M=4 decode, attn_out-like + # square / balanced (256, 256, 256), (512, 1024, 512), - (1, 4096, 4096), # decode: M=1 (ragged, kernel OOB-clips; no host pad) - (333, 576, 1024), # ragged M, N=576 (tile_n=64 divides it) - (17, 64, 512), # tiny ragged M, small N + # ragged M (partial last M-tile) + (17, 64, 512), # tiny ragged M, min N (= tile_n=64) + (100, 256, 512), # ragged M + (333, 576, 1024), # ragged M, N=576 (9*tile_n) + # prefill / production projections + (128, 1280, 8192), # qkv_proj + (1024, 1280, 8192), # qkv_proj, large M + (512, 8192, 1024), # attn_out + (2048, 8192, 1024), # attn_out, large M + # large N + (64, 7424, 8192), # hipmm preshuffle (N=7424 = 116*tile_n) ], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @@ -119,9 +135,7 @@ def test_gemm_a8w8_bpreshuffle_gfx1250(M, N, K, dtype, monkeypatch): assert out.shape == (M, N) assert out.dtype == dtype - rel, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (M={M},N={N},K={K})" - assert rel < 0.05, f"rel L1={rel} too high (M={M},N={N},K={K})" + _assert_close(ref, out, msg=f"M={M}, N={N}, K={K}, dtype={dtype}") @pytest.mark.parametrize("num_buffers", [2, 3, 4]) @@ -134,8 +148,7 @@ def test_num_buffers(num_buffers, monkeypatch): out = aiter.gemm_a8w8_bpreshuffle( aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 ) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (num_buffers={num_buffers})" + _assert_close(ref, out, msg=f"num_buffers={num_buffers}") @pytest.mark.parametrize("split_k", [2, 4]) @@ -147,34 +160,51 @@ def test_split_k(split_k, monkeypatch): bq_sh = shuffle_weight(bq, layout=(16, 16)) ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) out = aiter.gemm_a8w8_bpreshuffle(aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (split_k={split_k})" + _assert_close(ref, out, split_k=split_k, msg=f"split_k={split_k}") - # split-k must accumulate in fp32 + # split-k accumulates partial tiles via bf16 atomics, so it tracks split_k=1 + # closely (cos~1) but not bit-exactly. _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=1)) out_sk1 = aiter.gemm_a8w8_bpreshuffle( aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16 ) - rel, cos = _metrics(out, out_sk1) - assert rel < 1e-3, f"split_k={split_k} drifts from sk1 (rel L1={rel})" - assert cos > 0.9999, f"split_k={split_k} drifts from sk1 (cos={cos})" + _assert_close(out_sk1, out, split_k=split_k, msg=f"split_k={split_k} vs split_k=1") -@pytest.mark.parametrize("split_k", [2, 4]) -@pytest.mark.parametrize("M", [17, 100, 257]) +@pytest.mark.parametrize("split_k", [1, 2, 4]) +@pytest.mark.parametrize( + "M", + [ + 1, # extreme ragged: 1 of tile_m=128 rows valid (decode) + 17, # small sub-tile ragged + 100, # sub-tile ragged + 128, # aligned control (no remainder) + 257, # 2 full tiles + 1 ragged row + 300, # 2 full tiles + 44 ragged rows + 700, # 5 full tiles + 60 ragged rows + ], +) def test_ragged_m_split_k(M, split_k, monkeypatch): - """Ragged M with split-k exercises the per-lane (row < M) atomic predicate.""" - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) + """Ragged M with split-k exercises the per-lane (row < M) atomic predicate. + + M spans sub-tile, tile-aligned, and multi-tile remainders against tile_m=128. + """ torch.manual_seed(0) N, K = 256, 1024 # K/(split_k*tile_k) integral; chunk holds >= 2 buffers aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) + bq_sh = shuffle_weight(bq, layout=(16, 16)) + + _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) + out = aiter.gemm_a8w8_bpreshuffle(aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16) assert out.shape == (M, N) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (ragged M={M}, split_k={split_k})" + + _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=1)) + out_sk1 = aiter.gemm_a8w8_bpreshuffle( + aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16 + ) + _assert_close( + out_sk1, out, split_k=split_k, msg=f"ragged M={M} split_k={split_k} vs sk1" + ) @pytest.mark.parametrize( @@ -195,8 +225,7 @@ def test_warp_configs(m_warp, n_warp, tile_m, tile_n, monkeypatch): aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 ) assert out.shape == (M, N) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} low (mw{m_warp} nw{n_warp} t{tile_m}x{tile_n})" + _assert_close(ref, out, msg=f"mw{m_warp} nw{n_warp} t{tile_m}x{tile_n}") def test_vendored_oob_path(monkeypatch): @@ -218,9 +247,7 @@ def test_vendored_oob_path(monkeypatch): aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 ) assert out.shape == (M, N) - rel, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low (vendored OOB path)" - assert rel < 0.05, f"rel L1={rel} too high (vendored OOB path)" + _assert_close(ref, out, msg="vendored OOB path") def test_cluster(monkeypatch): @@ -235,8 +262,7 @@ def test_cluster(monkeypatch): out = aiter.gemm_a8w8_bpreshuffle( aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 ) - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low" + _assert_close(ref, out, msg="cluster_m=2 cluster_n=2") def test_backend_direct_writes_out(): @@ -254,5 +280,71 @@ def test_backend_direct_writes_out(): aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, out, 128, 128, 128 ) assert ret.data_ptr() == out.data_ptr() - _, cos = _metrics(out, ref) - assert cos > 0.99, f"cosine={cos} too low" + _assert_close(ref, out, msg="backend direct out") + + +@pytest.mark.parametrize("split_k", [1, 2]) +@pytest.mark.parametrize("M", [128, 100]) +def test_strided_a_input(M, split_k, monkeypatch): + """A as a row-slice of a wider buffer (stride(0) > K, e.g. a DeepSeek fused + activation) must run in place -- no contiguous copy -- and match the dense + result. Exercises the runtime lda path through the vendored TDM descriptor.""" + _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) + torch.manual_seed(0) + N, K = 256, 1024 # K/(split_k*tile_k) integral; chunk holds >= 2 buffers + aq, bq, a_scale, b_scale = _quant(M, N, K) + ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) + + pad = 64 # leading-dim padding -> stride(0) = K + pad + big = torch.empty(M, K + pad, dtype=aq.dtype, device="cuda") + big[:, :K].copy_(aq) + aq_strided = big[:, :K] + assert aq_strided.stride(0) == K + pad and aq_strided.stride(1) == 1 + assert not aq_strided.is_contiguous() + + out = aiter.gemm_a8w8_bpreshuffle( + aq_strided, + shuffle_weight(bq, layout=(16, 16)), + a_scale, + b_scale, + dtype=torch.bfloat16, + ) + assert out.shape == (M, N) + _assert_close( + ref, out, split_k=split_k, msg=f"strided A (M={M}, split_k={split_k})" + ) + + +@pytest.mark.parametrize("split_k", [1, 2]) +def test_strided_c_output(split_k, monkeypatch): + """Backend writes into a strided (column-sliced) Out without copying, and + leaves the leading-dim padding gap untouched. Exercises the runtime ldc path + for both the TDM store (split_k=1) and the atomic-add store (split_k>1).""" + from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( + run_preshuffle_gemm_a8_gfx1250, + ) + + torch.manual_seed(0) + M, N, K = 128, 256, 1024 + aq, bq, a_scale, b_scale = _quant(M, N, K) + ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) + + pad = 64 # Out is a column-slice of a [M, N+pad] buffer -> stride(0) = N + pad + big = torch.full((M, N + pad), -1.0, dtype=torch.bfloat16, device="cuda") + out = big[:, :N] + assert out.stride(0) == N + pad and not out.is_contiguous() + + ret = run_preshuffle_gemm_a8_gfx1250( + aq, + shuffle_weight(bq, layout=(16, 16)), + a_scale, + b_scale, + out, + 128, + 128, + 128, + split_k=split_k, + ) + assert ret.data_ptr() == out.data_ptr() # wrote in place, no copy-back + _assert_close(ref, out, split_k=split_k, msg=f"strided C (split_k={split_k})") + assert torch.all(big[:, N:] == -1.0), "kernel wrote into the C padding gap" From ef7468a550f0b09d1336a5b0202d7a2d0a6b7ca1 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sun, 7 Jun 2026 14:49:15 +0000 Subject: [PATCH 10/19] test(ptpc gemm): main-driven correctness + benchmark for gfx1250 WMMA aiter runs op_tests via `python3 ` (not pytest), so keep a dedicated main-driven test for the gfx1250 PTPC FP8 bpreshuffle backend, styled after op_tests/test_gemm_a8w8.py: @perftest-timed runner + @benchmark test_gemm into a DataFrame / markdown table, correctness in err/pass columns. - run_sweep: -mnk x -d shape sweep (real model shapes). - run_features: ragged M, strided A/C, split-k, m_warp/n_warp + cluster configs, vendored OOB descriptor, via test_gemm variants. Timing uses run_perftest(use_cuda_event=True) (the FlyDSL kernel is JIT-dispatched, which torch.profiler/ROCTracer can miss). Args mirror test_gemm_a8w8 (-d / -mnk / -o / --suffix) plus tile / split_k / num_buffers / m_warp / n_warp / --no-features. Skipped off gfx1250. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../test_gemm_a8w8_bpreshuffle_gfx1250.py | 589 ++++++++---------- 1 file changed, 273 insertions(+), 316 deletions(-) diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py index a0366634a5..fae128b217 100644 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py @@ -1,60 +1,44 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -"""Correctness tests for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. - -On gfx1250, aiter.gemm_a8w8_bpreshuffle dispatches its FlyDSL path to the WMMA -backend (bpreshuffle_gemm_gfx1250); its tuned kernelNames carry the prefix -``flydsl_bpreshuffle_wmma_``. Semantics are the ordinary a8w8 per-token -(x_scale[M]) / per-channel (w_scale[N]) fp8 GEMM, so inputs are quantized exactly -like the standard a8w8 path. Skipped off gfx1250. +"""Correctness + benchmark for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. + +Main-driven like op_tests/test_gemm_a8w8.py (aiter runs op_tests via python3, +not pytest). Sweeps -mnk x -d, plus a feature pass (ragged M, strided A/C, +split-k, warp/cluster configs, vendored OOB descriptor). Each row times the +WMMA backend with run_perftest(use_cuda_event=True) -- the FlyDSL kernel is +JIT-dispatched (hipModuleLaunchKernel), which torch.profiler can miss -- and +checks correctness vs the dense reference (err / pass columns). Skipped off +gfx1250. """ -import pytest +import argparse +import os +import sys + +import pandas as pd import torch import aiter from aiter.utility import dtypes -from aiter.test_common import checkAllclose +from aiter.test_common import checkAllclose, perftest, benchmark from aiter.ops.shuffle import shuffle_weight from aiter.jit.utils.chip_info import get_gfx_runtime as get_gfx -from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import wmma_kernel_name +from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import run_preshuffle_gemm_a8_gfx1250 -pytestmark = pytest.mark.skipif( - not torch.cuda.is_available() or get_gfx() != "gfx1250", - reason="gfx1250 WMMA a8w8 bpreshuffle requires a gfx1250 device", -) +TEST_NUM_ITERS = 50 +_DTYPE = {"bf16": torch.bfloat16, "f16": torch.float16} -def _kernel_name( - tile_m, - tile_n, - tile_k, - num_buffers, - split_k=1, - cluster_m=1, - cluster_n=1, - m_warp=2, - n_warp=2, -): - return wmma_kernel_name( - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - num_buffers=num_buffers, - split_k=split_k, - cluster_m=cluster_m, - cluster_n=cluster_n, - m_warp=m_warp, - n_warp=n_warp, - ) +def _dt(dtype): + return "bf16" if dtype == torch.bfloat16 else "f16" -def _quant(M, N, K): - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * 2.0 - b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * 2.0 - aq, a_scale = aiter.pertoken_quant(a, quant_dtype=dtypes.fp8) # [M, 1] - bq, b_scale = aiter.pertoken_quant(b, quant_dtype=dtypes.fp8) # [N, 1] +def _quant(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * 2.0 + b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16) * 2.0 + aq, a_scale = aiter.pertoken_quant(a, quant_dtype=dtypes.fp8) + bq, b_scale = aiter.pertoken_quant(b, quant_dtype=dtypes.fp8) return aq, bq, a_scale, b_scale @@ -64,287 +48,260 @@ def _ref(aq, bq, a_scale, b_scale, dtype): return (a_f @ b_f.t()).to(dtype) -def _inject_tuned_config(monkeypatch, name): - import aiter.ops.gemm_op_a8w8 as gmod - - config = {"libtype": "flydsl", "splitK": 1, "kernelName": name} - monkeypatch.setattr(gmod, "get_GEMM_config_with_quant_type", lambda *a, **k: config) +def _bound(split_k): + return 0.10 if split_k > 1 else 0.05 -def _assert_close(ref, out, *, split_k=1, msg=""): - rtol = atol = 2e-2 if split_k > 1 else 1e-2 - bound = 0.10 if split_k > 1 else 0.05 - err = checkAllclose( - ref, out, rtol=rtol, atol=atol, msg=msg, catastrophic_check=True - ) - assert err <= bound, f"{msg}: {err:.1%} of elements exceed tol (bound {bound:.0%})" - - -def test_kernel_name_roundtrips(): - """Every catalogue kernelName must decode back to its config.""" - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import parse_wmma_kernel_name - from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( - kernels_list, +@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True) +def run_gemm_flydsl( + xq, + wq_sh, + x_scale, + w_scale, + out, + tile, + num_buffers, + split_k, + m_warp, + n_warp, + cluster, +): + run_preshuffle_gemm_a8_gfx1250( + xq, + wq_sh, + x_scale, + w_scale, + out, + *tile, + num_buffers=num_buffers, + split_k=split_k, + m_warp=m_warp, + n_warp=n_warp, + cluster_m=cluster[0], + cluster_n=cluster[1], ) - - assert kernels_list, "WMMA catalogue is empty" - for ki in kernels_list.values(): - cfg = parse_wmma_kernel_name(ki.name) - assert cfg is not None, f"cannot parse {ki.name}" - assert (cfg["tile_m"], cfg["tile_n"], cfg["tile_k"]) == ( - ki.tile_m, - ki.tile_n, - ki.tile_k, + return out + + +@benchmark() +def test_gemm( + dtype, + m, + n, + k, + tile_m=128, + tile_n=128, + tile_k=128, + num_buffers=4, + split_k=1, + m_warp=2, + n_warp=2, + cluster_m=1, + cluster_n=1, + variant="dense", +): + torch.manual_seed(0) + xq, wq, x_scale, w_scale = _quant(m, n, k) + wq_sh = shuffle_weight(wq, layout=(16, 16)) + ref = _ref(xq, wq, x_scale, w_scale, dtype) + + xin = xq + if variant == "strided_a": + big = torch.empty(m, k + 64, dtype=xq.dtype, device="cuda") + big[:, :k].copy_(xq) + xin = big[:, :k] + if variant == "strided_c": + out = torch.full((m, n + 64), -1.0, dtype=dtype, device="cuda")[:, :n] + else: + out = torch.empty(m, n, dtype=dtype, device="cuda") + + kmod = None + if variant == "oob": + import aiter.ops.flydsl.kernels.gemm_fp8fp4_gfx1250 as kmod + + saved, kmod._TDM_HAS_OOB = kmod._TDM_HAS_OOB, False + try: + out, us = run_gemm_flydsl( + xin, + wq_sh, + x_scale, + w_scale, + out, + (tile_m, tile_n, tile_k), + num_buffers, + split_k, + m_warp, + n_warp, + (cluster_m, cluster_n), ) - assert cfg["num_buffers"] == ki.num_buffers and cfg["split_k"] == ki.split_k - assert cfg["cluster_m"] == ki.cluster_m and cfg["cluster_n"] == ki.cluster_n - assert cfg["m_warp"] == ki.m_warp and cfg["n_warp"] == ki.n_warp - - -@pytest.mark.parametrize( - "M,N,K", - [ - # decode (small M; ragged -> kernel OOB-clips, no host pad) - (1, 4096, 4096), # M=1 decode - (2, 1280, 8192), # M=2 decode, qkv_proj-like - (4, 8192, 1024), # M=4 decode, attn_out-like - # square / balanced - (256, 256, 256), - (512, 1024, 512), - # ragged M (partial last M-tile) - (17, 64, 512), # tiny ragged M, min N (= tile_n=64) - (100, 256, 512), # ragged M - (333, 576, 1024), # ragged M, N=576 (9*tile_n) - # prefill / production projections - (128, 1280, 8192), # qkv_proj - (1024, 1280, 8192), # qkv_proj, large M - (512, 8192, 1024), # attn_out - (2048, 8192, 1024), # attn_out, large M - # large N - (64, 7424, 8192), # hipmm preshuffle (N=7424 = 116*tile_n) - ], + finally: + if kmod is not None: + kmod._TDM_HAS_OOB = saved + + tol = 2e-2 if split_k > 1 else 1e-2 + err = float(checkAllclose(ref, out, rtol=tol, atol=tol, printLog=False)) + return { + "us": round(us, 2), + "tflops": round(2 * m * n * k / us / 1e6, 1), + "err": round(err, 4), + "pass": bool(err <= _bound(split_k)), + } + + +def _report(title, rows): + df = pd.DataFrame(rows) + aiter.logger.info("%s:\n%s", title, df.to_markdown(index=False)) + print(f"\n== {title} ==\n{df.to_markdown(index=False)}") + return df + + +def run_sweep(l_dtype, l_mnk, **cfg): + rows = [] + for dtype in l_dtype: + for m, n, k in l_mnk: + try: + rows.append(test_gemm(dtype, m, n, k, **cfg)) + except Exception as e: # noqa: BLE001 + rows.append( + { + "dtype": dtype, + "m": m, + "n": n, + "k": k, + "pass": False, + "note": str(e)[:50], + } + ) + return _report("shape sweep", rows) + + +def run_features(l_dtype): + dt = l_dtype[0] + cases = [ + dict(m=17, n=256, k=1024, num_buffers=2, variant="dense"), + dict(m=100, n=256, k=1024, num_buffers=2, split_k=2, variant="dense"), + dict(m=257, n=256, k=1024, num_buffers=2, split_k=4, variant="dense"), + dict(m=128, n=256, k=1024, num_buffers=2, variant="strided_a"), + dict(m=128, n=256, k=1024, num_buffers=2, split_k=2, variant="strided_c"), + dict(m=100, n=256, k=512, tile_m=64, tile_n=64, num_buffers=2, variant="oob"), + dict( + m=1, + n=256, + k=512, + tile_m=16, + tile_n=32, + tile_k=256, + num_buffers=2, + m_warp=1, + n_warp=2, + ), + dict( + m=1, + n=256, + k=512, + tile_m=32, + tile_n=64, + tile_k=256, + num_buffers=2, + m_warp=2, + n_warp=4, + ), + dict(m=512, n=512, k=512, num_buffers=2, cluster_m=2, cluster_n=2), + ] + rows = [] + for c in cases: + try: + rows.append(test_gemm(dt, **c)) + except Exception as e: # noqa: BLE001 + rows.append({"dtype": dt, **c, "pass": False, "note": str(e)[:50]}) + return _report("feature checks", rows) + + +_DEFAULT_MNK = [ + (1, 1280, 8192), + (32, 1280, 8192), + (64, 1280, 8192), + (128, 1280, 8192), + (192, 1280, 8192), + (256, 1280, 8192), + (320, 1280, 8192), + (512, 1280, 8192), + (1024, 1280, 8192), + (2048, 1280, 8192), + (4096, 1280, 8192), + (8192, 1280, 8192), + (16384, 1280, 8192), + (1, 8192, 1024), + (32, 8192, 1024), + (64, 8192, 1024), + (128, 8192, 1024), + (192, 8192, 1024), + (256, 8192, 1024), + (320, 8192, 1024), + (512, 8192, 1024), + (1024, 8192, 1024), + (2048, 8192, 1024), + (4096, 8192, 1024), + (8192, 8192, 1024), + (16384, 8192, 1024), + (16, 7424, 8192), + (32, 7424, 8192), + (48, 7424, 8192), + (64, 7424, 8192), + (4096, 7424, 8192), + (5120, 7424, 8192), + (8192, 7424, 8192), +] + +parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description="gfx1250 FlyDSL PTPC FP8 a8w8 bpreshuffle GEMM correctness + benchmark.", ) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -def test_gemm_a8w8_bpreshuffle_gfx1250(M, N, K, dtype, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 64, 128, 2)) - torch.manual_seed(0) - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, dtype) - bq_prepared = shuffle_weight(bq, layout=(16, 16)) - out = aiter.gemm_a8w8_bpreshuffle(aq, bq_prepared, a_scale, b_scale, dtype=dtype) - - assert out.shape == (M, N) - assert out.dtype == dtype - _assert_close(ref, out, msg=f"M={M}, N={N}, K={K}, dtype={dtype}") - - -@pytest.mark.parametrize("num_buffers", [2, 3, 4]) -def test_num_buffers(num_buffers, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, num_buffers)) - torch.manual_seed(0) - M, N, K = 256, 256, 1024 # K/tile_k = 8 >= 4 buffers - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - _assert_close(ref, out, msg=f"num_buffers={num_buffers}") - - -@pytest.mark.parametrize("split_k", [2, 4]) -def test_split_k(split_k, monkeypatch): - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) - torch.manual_seed(0) - M, N, K = 256, 256, 1024 - aq, bq, a_scale, b_scale = _quant(M, N, K) - bq_sh = shuffle_weight(bq, layout=(16, 16)) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle(aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16) - _assert_close(ref, out, split_k=split_k, msg=f"split_k={split_k}") - - # split-k accumulates partial tiles via bf16 atomics, so it tracks split_k=1 - # closely (cos~1) but not bit-exactly. - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=1)) - out_sk1 = aiter.gemm_a8w8_bpreshuffle( - aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16 - ) - _assert_close(out_sk1, out, split_k=split_k, msg=f"split_k={split_k} vs split_k=1") - - -@pytest.mark.parametrize("split_k", [1, 2, 4]) -@pytest.mark.parametrize( - "M", - [ - 1, # extreme ragged: 1 of tile_m=128 rows valid (decode) - 17, # small sub-tile ragged - 100, # sub-tile ragged - 128, # aligned control (no remainder) - 257, # 2 full tiles + 1 ragged row - 300, # 2 full tiles + 44 ragged rows - 700, # 5 full tiles + 60 ragged rows - ], +parser.add_argument( + "-d", "--dtype", nargs="*", choices=list(_DTYPE), default=list(_DTYPE) ) -def test_ragged_m_split_k(M, split_k, monkeypatch): - """Ragged M with split-k exercises the per-lane (row < M) atomic predicate. - - M spans sub-tile, tile-aligned, and multi-tile remainders against tile_m=128. - """ - torch.manual_seed(0) - N, K = 256, 1024 # K/(split_k*tile_k) integral; chunk holds >= 2 buffers - aq, bq, a_scale, b_scale = _quant(M, N, K) - bq_sh = shuffle_weight(bq, layout=(16, 16)) - - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) - out = aiter.gemm_a8w8_bpreshuffle(aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16) - assert out.shape == (M, N) - - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=1)) - out_sk1 = aiter.gemm_a8w8_bpreshuffle( - aq, bq_sh, a_scale, b_scale, dtype=torch.bfloat16 - ) - _assert_close( - out_sk1, out, split_k=split_k, msg=f"ragged M={M} split_k={split_k} vs sk1" - ) - - -@pytest.mark.parametrize( - "m_warp,n_warp,tile_m,tile_n", - [(1, 2, 16, 32), (1, 4, 16, 64), (2, 2, 32, 32), (2, 4, 32, 64)], +parser.add_argument("-mnk", type=dtypes.str2tuple, nargs="*", default=_DEFAULT_MNK) +parser.add_argument("--tile_m", type=int, default=128) +parser.add_argument("--tile_n", type=int, default=128) +parser.add_argument("--tile_k", type=int, default=128) +parser.add_argument("--num_buffers", type=int, default=4) +parser.add_argument("--split_k", type=int, default=1) +parser.add_argument("--m_warp", type=int, default=2) +parser.add_argument("--n_warp", type=int, default=2) +parser.add_argument( + "--no-features", action="store_true", help="Skip the feature checks." ) -def test_warp_configs(m_warp, n_warp, tile_m, tile_n, monkeypatch): - """m_warp/n_warp are tunable; m_warp=1 small-tile_m configs serve decode.""" - _inject_tuned_config( - monkeypatch, - _kernel_name(tile_m, tile_n, 256, 2, m_warp=m_warp, n_warp=n_warp), - ) - torch.manual_seed(0) - M, N, K = 1, 256, 512 - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - assert out.shape == (M, N) - _assert_close(ref, out, msg=f"mw{m_warp} nw{n_warp} t{tile_m}x{tile_n}") - - -def test_vendored_oob_path(monkeypatch): - """Force the vendored OOB descriptor builder (the older-flydsl fallback) and - verify ragged-M correctness, even when the installed flydsl has native OOB. - - Uses a tile config no other test compiles so the cached kernel is built fresh - through the vendored path while ``_TDM_HAS_OOB`` is patched off. - """ - import aiter.ops.flydsl.kernels.gemm_fp8fp4_gfx1250 as kmod - - monkeypatch.setattr(kmod, "_TDM_HAS_OOB", False) - _inject_tuned_config(monkeypatch, _kernel_name(64, 64, 128, 2)) - torch.manual_seed(0) - M, N, K = 100, 256, 512 # ragged M -> partial last M-tile via vendored desc - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - assert out.shape == (M, N) - _assert_close(ref, out, msg="vendored OOB path") - - -def test_cluster(monkeypatch): - """Workgroup cluster (cluster_m/n>1) over an evenly divisible grid.""" - _inject_tuned_config( - monkeypatch, _kernel_name(128, 128, 128, 2, cluster_m=2, cluster_n=2) - ) - torch.manual_seed(0) - M, N, K = 512, 512, 512 # grid (4, 4) divisible by cluster (2, 2) - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = aiter.gemm_a8w8_bpreshuffle( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, dtype=torch.bfloat16 - ) - _assert_close(ref, out, msg="cluster_m=2 cluster_n=2") - - -def test_backend_direct_writes_out(): - """The gfx1250 backend writes into the caller's Out tensor.""" - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( - run_preshuffle_gemm_a8_gfx1250, - ) - - torch.manual_seed(0) - M, N, K = 512, 512, 512 - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16) - ret = run_preshuffle_gemm_a8_gfx1250( - aq, shuffle_weight(bq, layout=(16, 16)), a_scale, b_scale, out, 128, 128, 128 - ) - assert ret.data_ptr() == out.data_ptr() - _assert_close(ref, out, msg="backend direct out") - - -@pytest.mark.parametrize("split_k", [1, 2]) -@pytest.mark.parametrize("M", [128, 100]) -def test_strided_a_input(M, split_k, monkeypatch): - """A as a row-slice of a wider buffer (stride(0) > K, e.g. a DeepSeek fused - activation) must run in place -- no contiguous copy -- and match the dense - result. Exercises the runtime lda path through the vendored TDM descriptor.""" - _inject_tuned_config(monkeypatch, _kernel_name(128, 128, 128, 2, split_k=split_k)) - torch.manual_seed(0) - N, K = 256, 1024 # K/(split_k*tile_k) integral; chunk holds >= 2 buffers - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - - pad = 64 # leading-dim padding -> stride(0) = K + pad - big = torch.empty(M, K + pad, dtype=aq.dtype, device="cuda") - big[:, :K].copy_(aq) - aq_strided = big[:, :K] - assert aq_strided.stride(0) == K + pad and aq_strided.stride(1) == 1 - assert not aq_strided.is_contiguous() - - out = aiter.gemm_a8w8_bpreshuffle( - aq_strided, - shuffle_weight(bq, layout=(16, 16)), - a_scale, - b_scale, - dtype=torch.bfloat16, - ) - assert out.shape == (M, N) - _assert_close( - ref, out, split_k=split_k, msg=f"strided A (M={M}, split_k={split_k})" - ) - - -@pytest.mark.parametrize("split_k", [1, 2]) -def test_strided_c_output(split_k, monkeypatch): - """Backend writes into a strided (column-sliced) Out without copying, and - leaves the leading-dim padding gap untouched. Exercises the runtime ldc path - for both the TDM store (split_k=1) and the atomic-add store (split_k>1).""" - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( - run_preshuffle_gemm_a8_gfx1250, +parser.add_argument("-o", "--output", type=str, default=None) +parser.add_argument("--suffix", type=str, default="results") +args = parser.parse_args() + +if not torch.cuda.is_available() or get_gfx() != "gfx1250": + print(f"Skipping: requires gfx1250 (current: {get_gfx()})") + sys.exit(0) + +l_dtype = [_DTYPE[d] for d in args.dtype] +df = run_sweep( + l_dtype, + args.mnk, + tile_m=args.tile_m, + tile_n=args.tile_n, + tile_k=args.tile_k, + num_buffers=args.num_buffers, + split_k=args.split_k, + m_warp=args.m_warp, + n_warp=args.n_warp, +) +dfs = [df] +if not args.no_features: + dfs.append(run_features(l_dtype)) + +if args.output: + os.makedirs(args.output, exist_ok=True) + out_path = os.path.join( + args.output, f"gemm_a8w8_bpreshuffle_gfx1250_{args.suffix}.csv" ) + pd.concat(dfs, ignore_index=True).to_csv(out_path, index=False) + print(f"Saved results to: {out_path}") - torch.manual_seed(0) - M, N, K = 128, 256, 1024 - aq, bq, a_scale, b_scale = _quant(M, N, K) - ref = _ref(aq, bq, a_scale, b_scale, torch.bfloat16) - - pad = 64 # Out is a column-slice of a [M, N+pad] buffer -> stride(0) = N + pad - big = torch.full((M, N + pad), -1.0, dtype=torch.bfloat16, device="cuda") - out = big[:, :N] - assert out.stride(0) == N + pad and not out.is_contiguous() - - ret = run_preshuffle_gemm_a8_gfx1250( - aq, - shuffle_weight(bq, layout=(16, 16)), - a_scale, - b_scale, - out, - 128, - 128, - 128, - split_k=split_k, - ) - assert ret.data_ptr() == out.data_ptr() # wrote in place, no copy-back - _assert_close(ref, out, split_k=split_k, msg=f"strided C (split_k={split_k})") - assert torch.all(big[:, N:] == -1.0), "kernel wrote into the C padding gap" +n_fail = sum(int((~d["pass"]).sum()) for d in dfs) +sys.exit(1 if n_fail else 0) From db11d2186bf76167d053b850faab9706a435ba7f Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 8 Jun 2026 00:56:14 +0000 Subject: [PATCH 11/19] fold gfx1250 WMMA ptpc tuning into a8w8 bpreshuffle flydsl tuner --- .../gemm_a8w8_bpreshuffle_wmma_tune.py | 287 ------------------ .../gemm_a8w8_bpreshuffle_tune.py | 89 ++++++ 2 files changed, 89 insertions(+), 287 deletions(-) delete mode 100644 aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py diff --git a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py b/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py deleted file mode 100644 index 154b617cd2..0000000000 --- a/aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py +++ /dev/null @@ -1,287 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Auto-tuner for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. - -FlyDSL-only counterpart of the MFMA tuner in -``csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py``: for each untuned -(M, N, K) it benchmarks the WMMA candidates from -``flydsl_gemm_a8w8_bpreshuffle_wmma_common`` and writes the winner's kernelName -(``flydsl_bpreshuffle_wmma_*``, libtype "flydsl") into the a8w8 bpreshuffle tuned -CSV. The public ``aiter.gemm_a8w8_bpreshuffle`` op then routes to it on gfx1250. - -Usage:: - - python aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py \ - --untune_file aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv \ - --tune_file aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv -""" - -import pandas as pd -import torch - -import aiter -from aiter import dtypes -from aiter.jit.core import AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE, AITER_ROOT_DIR -from aiter.utility.base_tuner import GemmCommonTuner -from aiter.utility.mp_tuner import mp_tuner -from aiter.ops.shuffle import shuffle_weight -from aiter.ops.flydsl.utils import is_flydsl_available -from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( - kernels_list, - kernel_fits_shape, -) - -if is_flydsl_available(): - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import run_preshuffle_gemm_a8_gfx1250 - -_OUT_TORCH = {"bf16": torch.bfloat16, "f16": torch.float16} -_Q_DTYPE_W = str(dtypes.fp8) - - -def run_torch(x, weight, x_scale, w_scale, dtype=torch.bfloat16): - a = x.to(torch.float32) * x_scale.to(torch.float32) - b = weight.to(torch.float32) * w_scale.to(torch.float32) - return (a @ b.t()).to(dtype) - - -def generate_data(m, n, k, seed, out_dtype="bf16", device="cuda"): - torch.manual_seed(seed) - x = torch.randn((m, k), dtype=torch.bfloat16, device=device) - weight = torch.randn((n, k), dtype=torch.bfloat16, device=device) - x, x_scale = aiter.pertoken_quant(x, quant_dtype=dtypes.fp8) - weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=dtypes.fp8) - weight_shuffle = shuffle_weight(weight, layout=(16, 16)) - out = torch.empty(m, n, dtype=_OUT_TORCH[out_dtype], device=device) - return { - "x": x, - "weight": weight, - "weight_shuffle": weight_shuffle, - "x_scale": x_scale, - "w_scale": w_scale, - "out": out, - } - - -def run_gemm(x, weight_shuffle, x_scale, w_scale, out, kernel_id): - ki = kernels_list[kernel_id] - run_preshuffle_gemm_a8_gfx1250( - x, - weight_shuffle, - x_scale, - w_scale, - out, - ki.tile_m, - ki.tile_n, - ki.tile_k, - num_buffers=ki.num_buffers, - split_k=ki.split_k, - cluster_m=ki.cluster_m, - cluster_n=ki.cluster_n, - m_warp=ki.m_warp, - n_warp=ki.n_warp, - ) - return out - - -class GemmA8W8BpreShuffleWmmaTuner(GemmCommonTuner): - ARG_DEFAULTS = { - **GemmCommonTuner.ARG_DEFAULTS, - "tune_file": f"{AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE}", - "untune_file": f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv", - "config_env_name": "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE", - } - - def _clear_op_caches(self): - from aiter.ops import gemm_op_a8w8 as _op - - _op.get_GEMM_config_with_quant_type.cache_clear() - _op._GEMM_QUANT_TYPE_CACHE.clear() - _op._GEMM_QUANT_TYPE_HAS_GFX.clear() - - def _setup_specific_arguments(self): - self.parser.add_argument( - "--out_dtype", - type=str, - default="bf16", - choices=["bf16", "f16"], - help="Output dtype to tune (run once per dtype your model needs)", - ) - - def calculate(self, results, bpes=(1, 1, 2)): - return super().calculate(results, bpes=bpes) - - def getKernelName(self, kernelId, libtype="flydsl"): - ki = kernels_list.get(kernelId) - return ki.name if ki is not None else None - - def result_to_df(self, results): - resultdf = pd.DataFrame(columns=self.columns) - for el in results: - info, time, err_ratio = el - keys, kernelId, splitK, kernelName, libtype = info - if time == self.INVALID_TIME: - kernelName = "None" - elif kernelName == "": - resolved = self.getKernelName(kernelId, libtype) - kernelName = "None" if resolved is None else str(resolved) - tflops, bw = self.calculate(el) - key_dict = dict(zip(self.keys, keys)) - key_dict.update( - { - "libtype": [libtype], - "kernelId": [kernelId], - "splitK": [splitK], - "us": [time], - "kernelName": [kernelName], - "errRatio": [err_ratio], - "tflops": [tflops], - "bw": [bw], - } - ) - temp = pd.DataFrame(key_dict) - resultdf = ( - temp - if resultdf.empty - else pd.concat([resultdf, temp], ignore_index=True) - ) - return resultdf - - def get_wmma_tune_task(self, info_keys, out_dtype, seed, args): - gfx, cu_num, M, N, K, q_dtype_w = info_keys - if not is_flydsl_available(): - return [] - run_keys = ["x", "weight_shuffle", "x_scale", "w_scale", "out"] - ref_keys = ["x", "weight", "x_scale", "w_scale"] - tasks = [] - for i in sorted(kernels_list.keys()): - ki = kernels_list[i] - if not kernel_fits_shape(ki, M, N, K): - continue - info = (info_keys, i, 0, ki.name, "flydsl") - tasks.append( - ( - info, - generate_data, - (M, N, K, seed, out_dtype), - run_gemm, - (run_keys, i), - {"num_warmup": args.warmup, "num_iters": args.iters}, - run_torch, - (ref_keys, _OUT_TORCH[out_dtype]), - {}, - None, - 1e-2, - 1e-2, - None, - None, - ("out",), - ) - ) - return tasks - - def tune(self, untunedf, tunedf, args): - mp_num = args.mp - shape_grouped = args.shape_grouped - errRatio = args.errRatio - cu_num = self.get_cu_num() - gfx = self.get_gfx() - out_dtype = args.out_dtype - task = [] - tasks_data = [] - seed = 0 - already = set() - if tunedf is not None and not tunedf.empty and "q_dtype_w" in tunedf.columns: - sub = tunedf[tunedf["q_dtype_w"] == _Q_DTYPE_W] - already = {(int(r.M), int(r.N), int(r.K)) for r in sub.itertuples()} - for i in range(len(untunedf)): - M = untunedf.loc[i, "M"] - N = untunedf.loc[i, "N"] - K = untunedf.loc[i, "K"] - if ( - "q_dtype_w" in untunedf.columns - and str(untunedf.loc[i, "q_dtype_w"]) != _Q_DTYPE_W - ): - continue - if (int(M), int(N), int(K)) in already: - continue - seed += 1 - prev = len(task) - info_keys = (gfx, cu_num, M, N, K, _Q_DTYPE_W) - task.extend(self.get_wmma_tune_task(info_keys, out_dtype, seed, args)) - tasks_data.append((len(task) - prev, ())) - ret = [] - if task: - ret = mp_tuner( - task, - tasks_data, - mp_num, - False, - shape_grouped, - errRatio, - timeout=args.timeout, - verbose=args.verbose, - ) - return ret - - def run_config(self, args): - from aiter.test_common import run_perftest, checkAllclose - - out_dtype = args.out_dtype - allowed = args.errRatio - results = [] - for i in range(len(self.untunedf)): - row = self.untunedf.iloc[i] - if "q_dtype_w" in row and str(row["q_dtype_w"]) != _Q_DTYPE_W: - continue - M, N, K = int(row["M"]), int(row["N"]), int(row["K"]) - shape_str = f"M{M}_N{N}_K{K}_fp8_{out_dtype}" - try: - d = generate_data(M, N, K, seed=0, out_dtype=out_dtype) - ref = run_torch( - d["x"], - d["weight"], - d["x_scale"], - d["w_scale"], - _OUT_TORCH[out_dtype], - ) - out, us = run_perftest( - aiter.gemm_a8w8_bpreshuffle, - d["x"], - d["weight_shuffle"], - d["x_scale"], - d["w_scale"], - dtype=_OUT_TORCH[out_dtype], - ) - err = checkAllclose(out, ref, msg=f"run_config {shape_str}") - status = "ok" if err <= allowed else f"mismatch:err_ratio={err:.6g}" - results.append({"shape": shape_str, "e2e_us": us, "status": status}) - except Exception as e: # noqa: BLE001 - results.append( - {"shape": shape_str, "e2e_us": -1, "status": f"error:{e}"} - ) - finally: - torch.cuda.empty_cache() - return results - - -if __name__ == "__main__": - key = ["gfx", "cu_num", "M", "N", "K", "q_dtype_w"] - resultList = [ - "libtype", - "kernelId", - "splitK", - "us", - "kernelName", - "tflops", - "bw", - "errRatio", - ] - tuner = GemmA8W8BpreShuffleWmmaTuner( - "GemmA8W8BpreShuffleWmmaTuner", - key=key, - resultList=resultList, - description="Auto-tuner for gfx1250 (WMMA) a8w8 bpreshuffle GEMM", - ) - args = tuner.parse_args() - tuner.run(args, False) diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py index bd9587b867..a87360b5f5 100755 --- a/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py +++ b/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py @@ -155,6 +155,34 @@ def run_gemm_flydsl(x, weight_shuffle, x_scale, w_scale, out, kernel_id): return out +def run_gemm_flydsl_gfx1250(x, weight_shuffle, x_scale, w_scale, out, kernel_id): + from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernels_list as kernels_list_flydsl_wmma, + ) + from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( + run_preshuffle_gemm_a8_gfx1250, + ) + + ki = kernels_list_flydsl_wmma[kernel_id] + run_preshuffle_gemm_a8_gfx1250( + x, + weight_shuffle, + x_scale, + w_scale, + out, + ki.tile_m, + ki.tile_n, + ki.tile_k, + num_buffers=ki.num_buffers, + split_k=ki.split_k, + cluster_m=ki.cluster_m, + cluster_n=ki.cluster_n, + m_warp=ki.m_warp, + n_warp=ki.n_warp, + ) + return out + + def generate_data( m, n, k, seed, dtype=dtypes.bf16, q_dtype_w=dtypes.fp8, is_asm=False, device="cuda" ): @@ -464,6 +492,10 @@ def get_flydsl_gemm_a8w8_bpreshuffle_tune_task( seed, ): gfx, cu_num, M, N, K, q_dtype_w = info_keys + + if gfx == "gfx1250": + return self._get_flydsl_tune_task_gfx1250(info_keys, seed) + q_dtype_eval = eval(q_dtype_w) if q_dtype_eval == dtypes.fp8: pass @@ -541,6 +573,63 @@ def get_flydsl_gemm_a8w8_bpreshuffle_tune_task( ) return tasks + def _get_flydsl_tune_task_gfx1250(self, info_keys, seed): + """gfx1250 WMMA ptpc tuning tasks for the FlyDSL libtype.""" + gfx, cu_num, M, N, K, q_dtype_w = info_keys + if eval(q_dtype_w) != dtypes.fp8: + print( + f"[FlyDSL][gfx1250] WMMA ptpc supports fp8 only, skipping {q_dtype_w}" + ) + return [] + if not is_flydsl_available(): + return [] + try: + from aiter.ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernels_list as kernels_list_flydsl_wmma, + kernel_fits_shape as kernel_fits_shape_wmma, + ) + except ImportError: + return [] + if not kernels_list_flydsl_wmma: + return [] + gemm_keys = ["x", "weight_shuffle", "x_scale", "w_scale", "out"] + ref_keys = ["x", "weight", "x_scale", "w_scale", "bias_f32"] + tasks = [] + for i in sorted(kernels_list_flydsl_wmma.keys()): + ki = kernels_list_flydsl_wmma[i] + if not kernel_fits_shape_wmma(ki, M, N, K): + continue + info = (info_keys, i, 0, ki.name, "flydsl") + tasks.append( + ( + info, + generate_data, + (M, N, K, seed, dtypes.bf16, dtypes.fp8), + run_gemm_flydsl_gfx1250, + ( + gemm_keys, + i, + ), + { + "num_warmup": args.warmup, + "num_iters": args.iters, + }, + run_torch, + ( + ref_keys, + dtypes.bf16, + ), + {}, + None, + 1e-2, + 0.01, + None, + None, + ("out",), + ) + ) + return tasks + def tune( self, untunedf, From 8e85a51644fa97dfd4dd368b4b04df8f5e4e5fcf Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 06:31:12 +0000 Subject: [PATCH 12/19] test(gemm a8w8): route gfx1250 bpreshuffle through production dispatch --- op_tests/test_gemm_a8w8.py | 158 ++++++++- .../test_gemm_a8w8_bpreshuffle_gfx1250.py | 307 ------------------ 2 files changed, 154 insertions(+), 311 deletions(-) delete mode 100644 op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 73cbd1659a..4f378684cb 100755 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -99,6 +99,41 @@ def run_gemm_ck_bpreshuffle(x, weight, x_scale, w_scale, dtype=dtypes.bf16): return aiter.gemm_a8w8_bpreshuffle(x, weight, x_scale, w_scale, None, dtype) +@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True) +def run_gemm_flydsl_gfx1250( + x, + weightshuffle, + x_scale, + w_scale, + out, + tile, + num_buffers, + split_k, + m_warp, + n_warp, + cluster, +): + from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( + run_preshuffle_gemm_a8_gfx1250, + ) + + run_preshuffle_gemm_a8_gfx1250( + x, + weightshuffle, + x_scale, + w_scale, + out, + *tile, + num_buffers=num_buffers, + split_k=split_k, + m_warp=m_warp, + n_warp=n_warp, + cluster_m=cluster[0], + cluster_n=cluster[1], + ) + return out + + @perftest() def run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias=None, dtype=dtypes.bf16): return aiter.gemm_a8w8_ASM(x, weightshuffle, x_scale, w_scale, bias) @@ -370,6 +405,115 @@ def test_normal_gemm_a8w8_pertoken_quant( return df +@benchmark() +def test_gemm_flydsl_feature( + dtype, + m, + n, + k, + tile_m=128, + tile_n=128, + tile_k=128, + num_buffers=4, + split_k=1, + m_warp=2, + n_warp=2, + cluster_m=1, + cluster_n=1, + variant="dense", +): + """One gfx1250 FlyDSL feature case (ragged M, strided A/C, split-k, or + warp/cluster) via the WMMA backend.""" + torch.manual_seed(0) + x = torch.randn((m, k), dtype=dtype, device="cuda") * 2.0 + weight = torch.randn((n, k), dtype=dtype, device="cuda") * 2.0 + x, x_scale = aiter.pertoken_quant(x, quant_dtype=dtypes.fp8) + weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=dtypes.fp8) + weightshuffle = shuffle_weight(weight, layout=(16, 16)) + a, _ = run_torch(x, weight, x_scale, w_scale, None, dtype) + + xin = x + if variant == "strided_a": + big = torch.empty(m, k + 64, dtype=x.dtype, device="cuda") + big[:, :k].copy_(x) + xin = big[:, :k] + if variant == "strided_c": + out = torch.full((m, n + 64), -1.0, dtype=dtype, device="cuda")[:, :n] + else: + out = torch.empty(m, n, dtype=dtype, device="cuda") + + f, us = run_gemm_flydsl_gfx1250( + xin, + weightshuffle, + x_scale, + w_scale, + out, + (tile_m, tile_n, tile_k), + num_buffers, + split_k, + m_warp, + n_warp, + (cluster_m, cluster_n), + ) + + tol = 2e-2 if split_k > 1 else 1e-2 + err = float(checkAllclose(a, f, rtol=tol, atol=tol, printLog=False)) + bound = 0.10 if split_k > 1 else 0.05 + return { + "us": round(us, 2), + "tflops": round(2 * m * n * k / us / 1e6, 1), + "err": round(err, 4), + "pass": bool(err <= bound), + } + + +def test_gemm_flydsl_gfx1250_features(l_dtype): + """gfx1250 FlyDSL feature pass: cases the production dispatch can't reach + (ragged M, strided A/C, split-k, warp/cluster).""" + dt = l_dtype[0] + cases = [ + dict(m=17, n=256, k=1024, num_buffers=2, variant="dense"), + dict(m=100, n=256, k=1024, num_buffers=2, split_k=2, variant="dense"), + dict(m=257, n=256, k=1024, num_buffers=2, split_k=4, variant="dense"), + dict(m=128, n=256, k=1024, num_buffers=2, variant="strided_a"), + dict(m=128, n=256, k=1024, num_buffers=2, split_k=2, variant="strided_c"), + dict( + m=1, + n=256, + k=512, + tile_m=16, + tile_n=32, + tile_k=256, + num_buffers=2, + m_warp=1, + n_warp=2, + ), + dict( + m=1, + n=256, + k=512, + tile_m=32, + tile_n=64, + tile_k=256, + num_buffers=2, + m_warp=2, + n_warp=4, + ), + dict(m=512, n=512, k=512, num_buffers=2, cluster_m=2, cluster_n=2), + ] + rows = [] + for c in cases: + try: + rows.append(test_gemm_flydsl_feature(dt, **c)) + except Exception as e: # noqa: BLE001 + rows.append({"dtype": dt, **c, "pass": False, "note": str(e)[:50]}) + df = pd.DataFrame(rows) + aiter.logger.info( + "gfx1250 flydsl feature checks (markdown):\n%s", df.to_markdown(index=False) + ) + return df + + def test_skinny_gemm_a8w8_pertoken_quant(): # seed = 8779 # torch.manual_seed(seed) @@ -614,10 +758,16 @@ def _iter_flydsl_csv_cases(): ) ) - df = test_normal_gemm_a8w8_pertoken_quant( - args.dtype, args.quantDtype, args.mnk, args.pad_a - ) - test_skinny_gemm_a8w8_pertoken_quant() + if get_gfx() == "gfx1250": + df = test_normal_gemm_a8w8_pertoken_quant( + args.dtype, [dtypes.fp8], args.mnk, args.pad_a, skip_ck=True + ) + test_gemm_flydsl_gfx1250_features(args.dtype) + else: + df = test_normal_gemm_a8w8_pertoken_quant( + args.dtype, args.quantDtype, args.mnk, args.pad_a + ) + test_skinny_gemm_a8w8_pertoken_quant() if args.output and df is not None: os.makedirs(args.output, exist_ok=True) diff --git a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py b/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py deleted file mode 100644 index fae128b217..0000000000 --- a/op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Correctness + benchmark for the gfx1250 (WMMA) a8w8 bpreshuffle GEMM. - -Main-driven like op_tests/test_gemm_a8w8.py (aiter runs op_tests via python3, -not pytest). Sweeps -mnk x -d, plus a feature pass (ragged M, strided A/C, -split-k, warp/cluster configs, vendored OOB descriptor). Each row times the -WMMA backend with run_perftest(use_cuda_event=True) -- the FlyDSL kernel is -JIT-dispatched (hipModuleLaunchKernel), which torch.profiler can miss -- and -checks correctness vs the dense reference (err / pass columns). Skipped off -gfx1250. -""" - -import argparse -import os -import sys - -import pandas as pd -import torch - -import aiter -from aiter.utility import dtypes -from aiter.test_common import checkAllclose, perftest, benchmark -from aiter.ops.shuffle import shuffle_weight -from aiter.jit.utils.chip_info import get_gfx_runtime as get_gfx -from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import run_preshuffle_gemm_a8_gfx1250 - -TEST_NUM_ITERS = 50 -_DTYPE = {"bf16": torch.bfloat16, "f16": torch.float16} - - -def _dt(dtype): - return "bf16" if dtype == torch.bfloat16 else "f16" - - -def _quant(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * 2.0 - b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16) * 2.0 - aq, a_scale = aiter.pertoken_quant(a, quant_dtype=dtypes.fp8) - bq, b_scale = aiter.pertoken_quant(b, quant_dtype=dtypes.fp8) - return aq, bq, a_scale, b_scale - - -def _ref(aq, bq, a_scale, b_scale, dtype): - a_f = aq.to(torch.float32) * a_scale.to(torch.float32) - b_f = bq.to(torch.float32) * b_scale.to(torch.float32) - return (a_f @ b_f.t()).to(dtype) - - -def _bound(split_k): - return 0.10 if split_k > 1 else 0.05 - - -@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True) -def run_gemm_flydsl( - xq, - wq_sh, - x_scale, - w_scale, - out, - tile, - num_buffers, - split_k, - m_warp, - n_warp, - cluster, -): - run_preshuffle_gemm_a8_gfx1250( - xq, - wq_sh, - x_scale, - w_scale, - out, - *tile, - num_buffers=num_buffers, - split_k=split_k, - m_warp=m_warp, - n_warp=n_warp, - cluster_m=cluster[0], - cluster_n=cluster[1], - ) - return out - - -@benchmark() -def test_gemm( - dtype, - m, - n, - k, - tile_m=128, - tile_n=128, - tile_k=128, - num_buffers=4, - split_k=1, - m_warp=2, - n_warp=2, - cluster_m=1, - cluster_n=1, - variant="dense", -): - torch.manual_seed(0) - xq, wq, x_scale, w_scale = _quant(m, n, k) - wq_sh = shuffle_weight(wq, layout=(16, 16)) - ref = _ref(xq, wq, x_scale, w_scale, dtype) - - xin = xq - if variant == "strided_a": - big = torch.empty(m, k + 64, dtype=xq.dtype, device="cuda") - big[:, :k].copy_(xq) - xin = big[:, :k] - if variant == "strided_c": - out = torch.full((m, n + 64), -1.0, dtype=dtype, device="cuda")[:, :n] - else: - out = torch.empty(m, n, dtype=dtype, device="cuda") - - kmod = None - if variant == "oob": - import aiter.ops.flydsl.kernels.gemm_fp8fp4_gfx1250 as kmod - - saved, kmod._TDM_HAS_OOB = kmod._TDM_HAS_OOB, False - try: - out, us = run_gemm_flydsl( - xin, - wq_sh, - x_scale, - w_scale, - out, - (tile_m, tile_n, tile_k), - num_buffers, - split_k, - m_warp, - n_warp, - (cluster_m, cluster_n), - ) - finally: - if kmod is not None: - kmod._TDM_HAS_OOB = saved - - tol = 2e-2 if split_k > 1 else 1e-2 - err = float(checkAllclose(ref, out, rtol=tol, atol=tol, printLog=False)) - return { - "us": round(us, 2), - "tflops": round(2 * m * n * k / us / 1e6, 1), - "err": round(err, 4), - "pass": bool(err <= _bound(split_k)), - } - - -def _report(title, rows): - df = pd.DataFrame(rows) - aiter.logger.info("%s:\n%s", title, df.to_markdown(index=False)) - print(f"\n== {title} ==\n{df.to_markdown(index=False)}") - return df - - -def run_sweep(l_dtype, l_mnk, **cfg): - rows = [] - for dtype in l_dtype: - for m, n, k in l_mnk: - try: - rows.append(test_gemm(dtype, m, n, k, **cfg)) - except Exception as e: # noqa: BLE001 - rows.append( - { - "dtype": dtype, - "m": m, - "n": n, - "k": k, - "pass": False, - "note": str(e)[:50], - } - ) - return _report("shape sweep", rows) - - -def run_features(l_dtype): - dt = l_dtype[0] - cases = [ - dict(m=17, n=256, k=1024, num_buffers=2, variant="dense"), - dict(m=100, n=256, k=1024, num_buffers=2, split_k=2, variant="dense"), - dict(m=257, n=256, k=1024, num_buffers=2, split_k=4, variant="dense"), - dict(m=128, n=256, k=1024, num_buffers=2, variant="strided_a"), - dict(m=128, n=256, k=1024, num_buffers=2, split_k=2, variant="strided_c"), - dict(m=100, n=256, k=512, tile_m=64, tile_n=64, num_buffers=2, variant="oob"), - dict( - m=1, - n=256, - k=512, - tile_m=16, - tile_n=32, - tile_k=256, - num_buffers=2, - m_warp=1, - n_warp=2, - ), - dict( - m=1, - n=256, - k=512, - tile_m=32, - tile_n=64, - tile_k=256, - num_buffers=2, - m_warp=2, - n_warp=4, - ), - dict(m=512, n=512, k=512, num_buffers=2, cluster_m=2, cluster_n=2), - ] - rows = [] - for c in cases: - try: - rows.append(test_gemm(dt, **c)) - except Exception as e: # noqa: BLE001 - rows.append({"dtype": dt, **c, "pass": False, "note": str(e)[:50]}) - return _report("feature checks", rows) - - -_DEFAULT_MNK = [ - (1, 1280, 8192), - (32, 1280, 8192), - (64, 1280, 8192), - (128, 1280, 8192), - (192, 1280, 8192), - (256, 1280, 8192), - (320, 1280, 8192), - (512, 1280, 8192), - (1024, 1280, 8192), - (2048, 1280, 8192), - (4096, 1280, 8192), - (8192, 1280, 8192), - (16384, 1280, 8192), - (1, 8192, 1024), - (32, 8192, 1024), - (64, 8192, 1024), - (128, 8192, 1024), - (192, 8192, 1024), - (256, 8192, 1024), - (320, 8192, 1024), - (512, 8192, 1024), - (1024, 8192, 1024), - (2048, 8192, 1024), - (4096, 8192, 1024), - (8192, 8192, 1024), - (16384, 8192, 1024), - (16, 7424, 8192), - (32, 7424, 8192), - (48, 7424, 8192), - (64, 7424, 8192), - (4096, 7424, 8192), - (5120, 7424, 8192), - (8192, 7424, 8192), -] - -parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter, - description="gfx1250 FlyDSL PTPC FP8 a8w8 bpreshuffle GEMM correctness + benchmark.", -) -parser.add_argument( - "-d", "--dtype", nargs="*", choices=list(_DTYPE), default=list(_DTYPE) -) -parser.add_argument("-mnk", type=dtypes.str2tuple, nargs="*", default=_DEFAULT_MNK) -parser.add_argument("--tile_m", type=int, default=128) -parser.add_argument("--tile_n", type=int, default=128) -parser.add_argument("--tile_k", type=int, default=128) -parser.add_argument("--num_buffers", type=int, default=4) -parser.add_argument("--split_k", type=int, default=1) -parser.add_argument("--m_warp", type=int, default=2) -parser.add_argument("--n_warp", type=int, default=2) -parser.add_argument( - "--no-features", action="store_true", help="Skip the feature checks." -) -parser.add_argument("-o", "--output", type=str, default=None) -parser.add_argument("--suffix", type=str, default="results") -args = parser.parse_args() - -if not torch.cuda.is_available() or get_gfx() != "gfx1250": - print(f"Skipping: requires gfx1250 (current: {get_gfx()})") - sys.exit(0) - -l_dtype = [_DTYPE[d] for d in args.dtype] -df = run_sweep( - l_dtype, - args.mnk, - tile_m=args.tile_m, - tile_n=args.tile_n, - tile_k=args.tile_k, - num_buffers=args.num_buffers, - split_k=args.split_k, - m_warp=args.m_warp, - n_warp=args.n_warp, -) -dfs = [df] -if not args.no_features: - dfs.append(run_features(l_dtype)) - -if args.output: - os.makedirs(args.output, exist_ok=True) - out_path = os.path.join( - args.output, f"gemm_a8w8_bpreshuffle_gfx1250_{args.suffix}.csv" - ) - pd.concat(dfs, ignore_index=True).to_csv(out_path, index=False) - print(f"Saved results to: {out_path}") - -n_fail = sum(int((~d["pass"]).sum()) for d in dfs) -sys.exit(1 if n_fail else 0) From f89cacc0a16a51479da9400f2c6708353e75929c Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 8 Jun 2026 07:54:22 +0000 Subject: [PATCH 13/19] test(a8w8): remove torch random seed --- op_tests/test_gemm_a8w8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 4f378684cb..3e68cbe0e4 100755 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -424,7 +424,6 @@ def test_gemm_flydsl_feature( ): """One gfx1250 FlyDSL feature case (ragged M, strided A/C, split-k, or warp/cluster) via the WMMA backend.""" - torch.manual_seed(0) x = torch.randn((m, k), dtype=dtype, device="cuda") * 2.0 weight = torch.randn((n, k), dtype=dtype, device="cuda") * 2.0 x, x_scale = aiter.pertoken_quant(x, quant_dtype=dtypes.fp8) From b21c3876bd0ee8847dface88a15aac2f203d1597 Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 11:50:16 +0000 Subject: [PATCH 14/19] test(gemm a8w8): drop flydsl feature tests, fold gfx1250 skip logic into normal sweep --- op_tests/test_gemm_a8w8.py | 165 +++---------------------------------- 1 file changed, 13 insertions(+), 152 deletions(-) diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 3e68cbe0e4..4106469a62 100755 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -99,41 +99,6 @@ def run_gemm_ck_bpreshuffle(x, weight, x_scale, w_scale, dtype=dtypes.bf16): return aiter.gemm_a8w8_bpreshuffle(x, weight, x_scale, w_scale, None, dtype) -@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True) -def run_gemm_flydsl_gfx1250( - x, - weightshuffle, - x_scale, - w_scale, - out, - tile, - num_buffers, - split_k, - m_warp, - n_warp, - cluster, -): - from aiter.ops.flydsl.bpreshuffle_gemm_gfx1250 import ( - run_preshuffle_gemm_a8_gfx1250, - ) - - run_preshuffle_gemm_a8_gfx1250( - x, - weightshuffle, - x_scale, - w_scale, - out, - *tile, - num_buffers=num_buffers, - split_k=split_k, - m_warp=m_warp, - n_warp=n_warp, - cluster_m=cluster[0], - cluster_n=cluster[1], - ) - return out - - @perftest() def run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias=None, dtype=dtypes.bf16): return aiter.gemm_a8w8_ASM(x, weightshuffle, x_scale, w_scale, bias) @@ -391,9 +356,18 @@ def calculate_total_valid_points(cu_count, aligned_k): def test_normal_gemm_a8w8_pertoken_quant( l_dtype, l_quantDtype, l_mnk, pad_a=128, skip_ck=False ): + is_gfx1250 = get_gfx() == "gfx1250" + if is_gfx1250 and not skip_ck: + aiter.logger.warning("gfx1250 has no CK a8w8 path; forcing skip_ck=True.") + skip_ck = True df = [] for dtype in l_dtype: for quantDtype in l_quantDtype: + if is_gfx1250 and quantDtype == dtypes.i8: + aiter.logger.warning( + "gfx1250 a8w8 only supports fp8 pertoken quant; skipping i8 shapes." + ) + continue for m, n, k in l_mnk: ret = test_gemm( dtype, m, n, k, quantDtype, pad_a=pad_a, skip_ck=skip_ck @@ -405,114 +379,6 @@ def test_normal_gemm_a8w8_pertoken_quant( return df -@benchmark() -def test_gemm_flydsl_feature( - dtype, - m, - n, - k, - tile_m=128, - tile_n=128, - tile_k=128, - num_buffers=4, - split_k=1, - m_warp=2, - n_warp=2, - cluster_m=1, - cluster_n=1, - variant="dense", -): - """One gfx1250 FlyDSL feature case (ragged M, strided A/C, split-k, or - warp/cluster) via the WMMA backend.""" - x = torch.randn((m, k), dtype=dtype, device="cuda") * 2.0 - weight = torch.randn((n, k), dtype=dtype, device="cuda") * 2.0 - x, x_scale = aiter.pertoken_quant(x, quant_dtype=dtypes.fp8) - weight, w_scale = aiter.pertoken_quant(weight, quant_dtype=dtypes.fp8) - weightshuffle = shuffle_weight(weight, layout=(16, 16)) - a, _ = run_torch(x, weight, x_scale, w_scale, None, dtype) - - xin = x - if variant == "strided_a": - big = torch.empty(m, k + 64, dtype=x.dtype, device="cuda") - big[:, :k].copy_(x) - xin = big[:, :k] - if variant == "strided_c": - out = torch.full((m, n + 64), -1.0, dtype=dtype, device="cuda")[:, :n] - else: - out = torch.empty(m, n, dtype=dtype, device="cuda") - - f, us = run_gemm_flydsl_gfx1250( - xin, - weightshuffle, - x_scale, - w_scale, - out, - (tile_m, tile_n, tile_k), - num_buffers, - split_k, - m_warp, - n_warp, - (cluster_m, cluster_n), - ) - - tol = 2e-2 if split_k > 1 else 1e-2 - err = float(checkAllclose(a, f, rtol=tol, atol=tol, printLog=False)) - bound = 0.10 if split_k > 1 else 0.05 - return { - "us": round(us, 2), - "tflops": round(2 * m * n * k / us / 1e6, 1), - "err": round(err, 4), - "pass": bool(err <= bound), - } - - -def test_gemm_flydsl_gfx1250_features(l_dtype): - """gfx1250 FlyDSL feature pass: cases the production dispatch can't reach - (ragged M, strided A/C, split-k, warp/cluster).""" - dt = l_dtype[0] - cases = [ - dict(m=17, n=256, k=1024, num_buffers=2, variant="dense"), - dict(m=100, n=256, k=1024, num_buffers=2, split_k=2, variant="dense"), - dict(m=257, n=256, k=1024, num_buffers=2, split_k=4, variant="dense"), - dict(m=128, n=256, k=1024, num_buffers=2, variant="strided_a"), - dict(m=128, n=256, k=1024, num_buffers=2, split_k=2, variant="strided_c"), - dict( - m=1, - n=256, - k=512, - tile_m=16, - tile_n=32, - tile_k=256, - num_buffers=2, - m_warp=1, - n_warp=2, - ), - dict( - m=1, - n=256, - k=512, - tile_m=32, - tile_n=64, - tile_k=256, - num_buffers=2, - m_warp=2, - n_warp=4, - ), - dict(m=512, n=512, k=512, num_buffers=2, cluster_m=2, cluster_n=2), - ] - rows = [] - for c in cases: - try: - rows.append(test_gemm_flydsl_feature(dt, **c)) - except Exception as e: # noqa: BLE001 - rows.append({"dtype": dt, **c, "pass": False, "note": str(e)[:50]}) - df = pd.DataFrame(rows) - aiter.logger.info( - "gfx1250 flydsl feature checks (markdown):\n%s", df.to_markdown(index=False) - ) - return df - - def test_skinny_gemm_a8w8_pertoken_quant(): # seed = 8779 # torch.manual_seed(seed) @@ -757,15 +623,10 @@ def _iter_flydsl_csv_cases(): ) ) - if get_gfx() == "gfx1250": - df = test_normal_gemm_a8w8_pertoken_quant( - args.dtype, [dtypes.fp8], args.mnk, args.pad_a, skip_ck=True - ) - test_gemm_flydsl_gfx1250_features(args.dtype) - else: - df = test_normal_gemm_a8w8_pertoken_quant( - args.dtype, args.quantDtype, args.mnk, args.pad_a - ) + df = test_normal_gemm_a8w8_pertoken_quant( + args.dtype, args.quantDtype, args.mnk, args.pad_a + ) + if get_gfx() != "gfx1250": test_skinny_gemm_a8w8_pertoken_quant() if args.output and df is not None: From ea7893bb071eb0f2e26d26c82c5e30061ecaa87a Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 12:33:50 +0000 Subject: [PATCH 15/19] add gfx1250 gemm a8w8 tuned csv --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 75 ++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 1f9e989ca4..6cc2e64d78 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,6 +549,39 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 +gfx1250,256,1,1280,8192,torch.float8_e4m3fn,flydsl,9,0,7.7255,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,2.71,1358.68,0.0102 +gfx1250,256,32,1280,8192,torch.float8_e4m3fn,flydsl,9,0,13.2128,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,50.79,819.65,0.0114 +gfx1250,256,64,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.1176,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,165.34,1376.5,0.0109 +gfx1250,256,128,1280,8192,torch.float8_e4m3fn,flydsl,1,0,7.8052,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,343.92,1519.76,0.011 +gfx1250,256,192,1280,8192,torch.float8_e4m3fn,flydsl,1,0,8.2986,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,485.21,1512.32,0.0112 +gfx1250,256,256,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.5726,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,626.26,1544.25,0.0111 +gfx1250,256,320,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.8099,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,761.74,1580.77,0.0113 +gfx1250,256,512,1280,8192,torch.float8_e4m3fn,flydsl,8,0,30.67,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,350.1,521.38,0.0113 +gfx1250,256,1024,1280,8192,torch.float8_e4m3fn,flydsl,10,0,29.4432,flydsl_bpreshuffle_wmma_t32x128x256_mw2_nw2_nb4_sk1_cm1_cn1,729.36,730.08,0.0113 +gfx1250,256,2048,1280,8192,torch.float8_e4m3fn,flydsl,13,0,31.1003,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,1381.01,1045.19,0.0114 +gfx1250,256,4096,1280,8192,torch.float8_e4m3fn,flydsl,23,0,45.4271,flydsl_bpreshuffle_wmma_t256x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1890.93,1200.3,0.0113 +gfx1250,256,8192,1280,8192,torch.float8_e4m3fn,flydsl,25,0,54.0233,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3180.09,1824.51,0.0114 +gfx1250,256,16384,1280,8192,torch.float8_e4m3fn,flydsl,25,0,91.547,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3753.23,2038.81,0.0113 +gfx1250,256,1,8192,1024,torch.float8_e4m3fn,flydsl,18,0,10.5165,flydsl_bpreshuffle_wmma_t128x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1.6,799.32,0.0 +gfx1250,256,32,8192,1024,torch.float8_e4m3fn,flydsl,12,0,12.2416,flydsl_bpreshuffle_wmma_t32x256x256_mw2_nw4_nb4_sk1_cm1_cn1,43.86,730.76,0.0 +gfx1250,256,64,8192,1024,torch.float8_e4m3fn,flydsl,3,0,14.0891,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,76.21,674.47,0.0 +gfx1250,256,128,8192,1024,torch.float8_e4m3fn,flydsl,6,0,12.595,flydsl_bpreshuffle_wmma_t16x256x256_mw1_nw4_nb4_sk1_cm1_cn1,170.5,842.94,0.0 +gfx1250,256,192,8192,1024,torch.float8_e4m3fn,flydsl,8,0,19.2571,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,167.27,609.18,0.0 +gfx1250,256,256,8192,1024,torch.float8_e4m3fn,flydsl,7,0,14.4043,flydsl_bpreshuffle_wmma_t16x512x128_mw1_nw4_nb4_sk1_cm1_cn1,298.17,891.75,0.0 +gfx1250,256,320,8192,1024,torch.float8_e4m3fn,flydsl,18,0,19.8423,flydsl_bpreshuffle_wmma_t128x128x128_mw2_nw2_nb4_sk1_cm1_cn1,270.57,703.51,0.0 +gfx1250,256,512,8192,1024,torch.float8_e4m3fn,flydsl,20,0,23.9115,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,359.24,723.56,0.0 +gfx1250,256,1024,8192,1024,torch.float8_e4m3fn,flydsl,20,0,28.1436,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,610.44,931.45,0.0 +gfx1250,256,2048,8192,1024,torch.float8_e4m3fn,flydsl,17,0,35.2259,flydsl_bpreshuffle_wmma_t64x512x128_mw1_nw4_nb3_sk1_cm1_cn1,975.41,1250.22,0.0 +gfx1250,256,4096,8192,1024,torch.float8_e4m3fn,flydsl,25,0,38.5197,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1784.01,2068.86,0.0 +gfx1250,256,8192,8192,1024,torch.float8_e4m3fn,flydsl,25,0,51.4394,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2671.86,2935.39,0.0 +gfx1250,256,16384,8192,1024,torch.float8_e4m3fn,flydsl,25,0,88.565,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3103.69,3315.09,0.0 +gfx1250,256,16,7424,8192,torch.float8_e4m3fn,flydsl,9,0,28.7155,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,67.77,2130.77,0.011 +gfx1250,256,32,7424,8192,torch.float8_e4m3fn,flydsl,3,0,30.5471,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,127.42,2015.07,0.0115 +gfx1250,256,48,7424,8192,torch.float8_e4m3fn,flydsl,3,0,31.5982,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,184.77,1959.71,0.0116 +gfx1250,256,64,7424,8192,torch.float8_e4m3fn,flydsl,9,0,35.7317,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,217.86,1743.33,0.0115 +gfx1250,256,4096,7424,8192,torch.float8_e4m3fn,flydsl,25,0,132.4934,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3760.31,1171.3,0.0113 +gfx1250,256,5120,7424,8192,torch.float8_e4m3fn,flydsl,25,0,146.392,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4254.13,1221.26,0.0113 +gfx1250,256,8192,7424,8192,torch.float8_e4m3fn,flydsl,25,0,203.6344,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4893.24,1225.54,0.0113 gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066 gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,381,0,8.2149,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk4_cm1_cn1,7.37,1845.62,0.0078 gfx1250,256,4,2112,7168,torch.float8_e4m3fn,flydsl,294,0,28.6242,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,4.23,530.47,0.0099 @@ -556,6 +589,26 @@ gfx1250,256,8,2112,7168,torch.float8_e4m3fn,flydsl,303,0,27.762,flydsl_bpreshuff gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,25,0,30.8522,flydsl_bpreshuffle_wmma_t32x32x256_mw2_nw2_nb3_sk1_cm1_cn1,15.7,496.6,0.011 gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,286,0,28.9229,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,33.5,536.02,0.0105 gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,269,0,31.7564,flydsl_bpreshuffle_wmma_t32x32x256_mw1_nw2_nb4_sk1_cm1_cn1,61.02,499.68,0.0098 +gfx1250,256,128,2112,7168,torch.float8_e4m3fn,flydsl,9,0,8.326,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,465.47,1993.39,0.0096 +gfx1250,256,256,2112,7168,torch.float8_e4m3fn,flydsl,19,0,25.7514,flydsl_bpreshuffle_wmma_t128x192x128_mw2_nw2_nb4_sk1_cm1_cn1,301.0,701.13,0.0098 +gfx1250,256,512,2112,7168,torch.float8_e4m3fn,flydsl,398,0,40.0855,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,386.73,523.17,0.0098 +gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096 +gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,398,0,126.0903,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,491.78,305.1,0.0098 +gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,398,0,240.7634,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,515.1,256.69,0.0098 +gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 +gfx1250,256,16384,2112,7168,torch.float8_e4m3fn,flydsl,398,0,928.8018,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,534.1,217.25,0.0098 +gfx1250,256,32768,2112,7168,torch.float8_e4m3fn,flydsl,398,0,1846.1863,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,537.4,210.4,0.0098 +gfx1250,256,65536,2112,7168,torch.float8_e4m3fn,flydsl,398,0,3680.9554,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,539.07,206.94,0.0098 +gfx1250,256,1,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7073,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,7.41,3704.86,0.0237 +gfx1250,256,2,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7156,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,14.81,3704.86,0.0237 +gfx1250,256,4,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7322,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,29.61,3704.86,0.0237 +gfx1250,256,8,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7654,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,59.15,3704.86,0.0237 +gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.8317,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,118.06,3704.86,0.0237 +gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.9644,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,235.14,3704.86,0.0237 +gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.2297,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,466.41,3704.86,0.0237 +gfx1250,256,128,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.7604,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,917.72,3704.86,0.0237 +gfx1250,256,256,7168,16384,torch.float8_e4m3fn,flydsl,500,0,33.8217,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,1777.84,3704.86,0.0237 +gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,500,0,51.8551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2319.14,2568.1,0.0237 gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2612.66,1617.42,0.0237 gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237 gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 @@ -564,6 +617,16 @@ gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpr gfx1250,256,10240,7168,16384,torch.float8_e4m3fn,flydsl,302,0,806.3905,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2982.65,535.74,0.0237 gfx1250,256,20480,7168,16384,torch.float8_e4m3fn,flydsl,302,0,1652.4551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2911.04,451.8,0.0237 gfx1250,256,65536,7168,16384,torch.float8_e4m3fn,flydsl,302,0,5987.8537,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2570.73,355.84,0.0237 +gfx1250,256,1,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.4757,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1.96,982.42,0.0 +gfx1250,256,2,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.5011,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3.92,983.09,0.0 +gfx1250,256,4,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.5519,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,7.83,984.42,0.0 +gfx1250,256,8,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.6536,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,15.63,987.08,0.0 +gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.8569,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,31.09,992.35,0.0 +gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,302,0,39.2634,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,61.53,1002.73,0.0 +gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,302,0,40.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,120.57,1022.86,0.0 +gfx1250,256,128,24576,1536,torch.float8_e4m3fn,flydsl,302,0,41.7027,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,231.73,1060.76,0.0 +gfx1250,256,256,24576,1536,torch.float8_e4m3fn,flydsl,302,0,44.9552,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,429.92,1128.34,0.0 +gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,302,0,51.4601,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,751.16,1237.87,0.0 gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0 gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0 gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 @@ -572,6 +635,16 @@ gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpr gfx1250,256,10240,24576,1536,torch.float8_e4m3fn,flydsl,302,0,316.7838,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2440.45,1757.65,0.0 gfx1250,256,20480,24576,1536,torch.float8_e4m3fn,flydsl,302,0,607.1774,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2546.52,1771.87,0.0 gfx1250,256,65536,24576,1536,torch.float8_e4m3fn,flydsl,302,0,2046.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2418.19,1641.99,0.0 +gfx1250,256,1,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5109,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,0.72,362.14,0.0 +gfx1250,256,2,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.515,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1.44,363.52,0.0 +gfx1250,256,4,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5232,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2.88,366.3,0.0 +gfx1250,256,8,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5396,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,5.77,371.85,0.0 +gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5724,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,11.53,382.93,0.0 +gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.638,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,23.02,405.05,0.0 +gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.7691,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,45.92,449.11,0.0 +gfx1250,256,128,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.0315,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,91.32,536.48,0.0 +gfx1250,256,256,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.5561,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,180.63,708.33,0.0 +gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,302,0,48.6054,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,353.46,1040.91,0.0 gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,302,0,50.7041,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,677.65,1664.76,0.0 gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,302,0,54.9014,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1251.69,2769.39,0.0 gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,302,0,88.9453,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1545.21,3230.19,0.0 @@ -579,4 +652,4 @@ gfx1250,256,6144,32768,512,torch.float8_e4m3fn,flydsl,302,0,119.3873,flydsl_bpre gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,302,0,150.5706,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1825.57,3704.86,0.0 gfx1250,256,10240,32768,512,torch.float8_e4m3fn,flydsl,302,0,187.2761,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1834.71,3701.0,0.0 gfx1250,256,20480,32768,512,torch.float8_e4m3fn,flydsl,302,0,393.5255,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1746.25,3479.93,0.0 -gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096 +gfx1250,256,65536,32768,512,torch.float8_e4m3fn,flydsl,21,0,846.9039,flydsl_bpreshuffle_wmma_t128x512x128_mw2_nw4_nb3_sk1_cm1_cn1,2596.54,5130.81,0.0 From d5c0eeb7356db42c4a2939a5033f8ec7f2fcee1f Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 14:29:45 +0000 Subject: [PATCH 16/19] remove dup tuned shapes --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 64 ------------------- 1 file changed, 64 deletions(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 6cc2e64d78..7ba0373836 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,43 +549,6 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 -gfx1250,256,1,1280,8192,torch.float8_e4m3fn,flydsl,9,0,7.7255,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,2.71,1358.68,0.0102 -gfx1250,256,32,1280,8192,torch.float8_e4m3fn,flydsl,9,0,13.2128,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,50.79,819.65,0.0114 -gfx1250,256,64,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.1176,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,165.34,1376.5,0.0109 -gfx1250,256,128,1280,8192,torch.float8_e4m3fn,flydsl,1,0,7.8052,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,343.92,1519.76,0.011 -gfx1250,256,192,1280,8192,torch.float8_e4m3fn,flydsl,1,0,8.2986,flydsl_bpreshuffle_wmma_t16x64x512_mw1_nw2_nb4_sk1_cm1_cn1,485.21,1512.32,0.0112 -gfx1250,256,256,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.5726,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,626.26,1544.25,0.0111 -gfx1250,256,320,1280,8192,torch.float8_e4m3fn,flydsl,9,0,8.8099,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,761.74,1580.77,0.0113 -gfx1250,256,512,1280,8192,torch.float8_e4m3fn,flydsl,8,0,30.67,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,350.1,521.38,0.0113 -gfx1250,256,1024,1280,8192,torch.float8_e4m3fn,flydsl,10,0,29.4432,flydsl_bpreshuffle_wmma_t32x128x256_mw2_nw2_nb4_sk1_cm1_cn1,729.36,730.08,0.0113 -gfx1250,256,2048,1280,8192,torch.float8_e4m3fn,flydsl,13,0,31.1003,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,1381.01,1045.19,0.0114 -gfx1250,256,4096,1280,8192,torch.float8_e4m3fn,flydsl,23,0,45.4271,flydsl_bpreshuffle_wmma_t256x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1890.93,1200.3,0.0113 -gfx1250,256,8192,1280,8192,torch.float8_e4m3fn,flydsl,25,0,54.0233,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3180.09,1824.51,0.0114 -gfx1250,256,16384,1280,8192,torch.float8_e4m3fn,flydsl,25,0,91.547,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3753.23,2038.81,0.0113 -gfx1250,256,1,8192,1024,torch.float8_e4m3fn,flydsl,18,0,10.5165,flydsl_bpreshuffle_wmma_t128x128x128_mw2_nw2_nb4_sk1_cm1_cn1,1.6,799.32,0.0 -gfx1250,256,32,8192,1024,torch.float8_e4m3fn,flydsl,12,0,12.2416,flydsl_bpreshuffle_wmma_t32x256x256_mw2_nw4_nb4_sk1_cm1_cn1,43.86,730.76,0.0 -gfx1250,256,64,8192,1024,torch.float8_e4m3fn,flydsl,3,0,14.0891,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,76.21,674.47,0.0 -gfx1250,256,128,8192,1024,torch.float8_e4m3fn,flydsl,6,0,12.595,flydsl_bpreshuffle_wmma_t16x256x256_mw1_nw4_nb4_sk1_cm1_cn1,170.5,842.94,0.0 -gfx1250,256,192,8192,1024,torch.float8_e4m3fn,flydsl,8,0,19.2571,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,167.27,609.18,0.0 -gfx1250,256,256,8192,1024,torch.float8_e4m3fn,flydsl,7,0,14.4043,flydsl_bpreshuffle_wmma_t16x512x128_mw1_nw4_nb4_sk1_cm1_cn1,298.17,891.75,0.0 -gfx1250,256,320,8192,1024,torch.float8_e4m3fn,flydsl,18,0,19.8423,flydsl_bpreshuffle_wmma_t128x128x128_mw2_nw2_nb4_sk1_cm1_cn1,270.57,703.51,0.0 -gfx1250,256,512,8192,1024,torch.float8_e4m3fn,flydsl,20,0,23.9115,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,359.24,723.56,0.0 -gfx1250,256,1024,8192,1024,torch.float8_e4m3fn,flydsl,20,0,28.1436,flydsl_bpreshuffle_wmma_t128x256x128_mw2_nw4_nb4_sk1_cm1_cn1,610.44,931.45,0.0 -gfx1250,256,2048,8192,1024,torch.float8_e4m3fn,flydsl,17,0,35.2259,flydsl_bpreshuffle_wmma_t64x512x128_mw1_nw4_nb3_sk1_cm1_cn1,975.41,1250.22,0.0 -gfx1250,256,4096,8192,1024,torch.float8_e4m3fn,flydsl,25,0,38.5197,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1784.01,2068.86,0.0 -gfx1250,256,8192,8192,1024,torch.float8_e4m3fn,flydsl,25,0,51.4394,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2671.86,2935.39,0.0 -gfx1250,256,16384,8192,1024,torch.float8_e4m3fn,flydsl,25,0,88.565,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3103.69,3315.09,0.0 -gfx1250,256,16,7424,8192,torch.float8_e4m3fn,flydsl,9,0,28.7155,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,67.77,2130.77,0.011 -gfx1250,256,32,7424,8192,torch.float8_e4m3fn,flydsl,3,0,30.5471,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,127.42,2015.07,0.0115 -gfx1250,256,48,7424,8192,torch.float8_e4m3fn,flydsl,3,0,31.5982,flydsl_bpreshuffle_wmma_t16x128x256_mw1_nw2_nb4_sk1_cm1_cn1,184.77,1959.71,0.0116 -gfx1250,256,64,7424,8192,torch.float8_e4m3fn,flydsl,9,0,35.7317,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,217.86,1743.33,0.0115 -gfx1250,256,4096,7424,8192,torch.float8_e4m3fn,flydsl,25,0,132.4934,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3760.31,1171.3,0.0113 -gfx1250,256,5120,7424,8192,torch.float8_e4m3fn,flydsl,25,0,146.392,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4254.13,1221.26,0.0113 -gfx1250,256,8192,7424,8192,torch.float8_e4m3fn,flydsl,25,0,203.6344,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,4893.24,1225.54,0.0113 -gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066 -gfx1250,256,2,2112,7168,torch.float8_e4m3fn,flydsl,381,0,8.2149,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk4_cm1_cn1,7.37,1845.62,0.0078 -gfx1250,256,4,2112,7168,torch.float8_e4m3fn,flydsl,294,0,28.6242,flydsl_bpreshuffle_wmma_t64x64x256_mw2_nw2_nb4_sk1_cm1_cn1,4.23,530.47,0.0099 -gfx1250,256,8,2112,7168,torch.float8_e4m3fn,flydsl,303,0,27.762,flydsl_bpreshuffle_wmma_t16x64x128_mw1_nw4_nb4_sk1_cm1_cn1,8.72,548.59,0.0102 gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,25,0,30.8522,flydsl_bpreshuffle_wmma_t32x32x256_mw2_nw2_nb3_sk1_cm1_cn1,15.7,496.6,0.011 gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,286,0,28.9229,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,33.5,536.02,0.0105 gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,269,0,31.7564,flydsl_bpreshuffle_wmma_t32x32x256_mw1_nw2_nb4_sk1_cm1_cn1,61.02,499.68,0.0098 @@ -596,13 +559,6 @@ gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpres gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,398,0,126.0903,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,491.78,305.1,0.0098 gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,398,0,240.7634,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,515.1,256.69,0.0098 gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 -gfx1250,256,16384,2112,7168,torch.float8_e4m3fn,flydsl,398,0,928.8018,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,534.1,217.25,0.0098 -gfx1250,256,32768,2112,7168,torch.float8_e4m3fn,flydsl,398,0,1846.1863,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,537.4,210.4,0.0098 -gfx1250,256,65536,2112,7168,torch.float8_e4m3fn,flydsl,398,0,3680.9554,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,539.07,206.94,0.0098 -gfx1250,256,1,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7073,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,7.41,3704.86,0.0237 -gfx1250,256,2,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7156,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,14.81,3704.86,0.0237 -gfx1250,256,4,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7322,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,29.61,3704.86,0.0237 -gfx1250,256,8,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.7654,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,59.15,3704.86,0.0237 gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.8317,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,118.06,3704.86,0.0237 gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.9644,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,235.14,3704.86,0.0237 gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.2297,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,466.41,3704.86,0.0237 @@ -612,15 +568,7 @@ gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,500,0,51.8551,flydsl_bpres gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2612.66,1617.42,0.0237 gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237 gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 -gfx1250,256,6144,7168,16384,torch.float8_e4m3fn,flydsl,302,0,475.4687,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3035.13,643.96,0.0237 gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237 -gfx1250,256,10240,7168,16384,torch.float8_e4m3fn,flydsl,302,0,806.3905,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2982.65,535.74,0.0237 -gfx1250,256,20480,7168,16384,torch.float8_e4m3fn,flydsl,302,0,1652.4551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2911.04,451.8,0.0237 -gfx1250,256,65536,7168,16384,torch.float8_e4m3fn,flydsl,302,0,5987.8537,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2570.73,355.84,0.0237 -gfx1250,256,1,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.4757,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1.96,982.42,0.0 -gfx1250,256,2,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.5011,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3.92,983.09,0.0 -gfx1250,256,4,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.5519,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,7.83,984.42,0.0 -gfx1250,256,8,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.6536,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,15.63,987.08,0.0 gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.8569,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,31.09,992.35,0.0 gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,302,0,39.2634,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,61.53,1002.73,0.0 gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,302,0,40.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,120.57,1022.86,0.0 @@ -630,15 +578,7 @@ gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,302,0,51.4601,flydsl_bpres gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0 gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0 gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 -gfx1250,256,6144,24576,1536,torch.float8_e4m3fn,flydsl,302,0,197.9044,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2343.84,1764.37,0.0 gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0 -gfx1250,256,10240,24576,1536,torch.float8_e4m3fn,flydsl,302,0,316.7838,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2440.45,1757.65,0.0 -gfx1250,256,20480,24576,1536,torch.float8_e4m3fn,flydsl,302,0,607.1774,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2546.52,1771.87,0.0 -gfx1250,256,65536,24576,1536,torch.float8_e4m3fn,flydsl,302,0,2046.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2418.19,1641.99,0.0 -gfx1250,256,1,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5109,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,0.72,362.14,0.0 -gfx1250,256,2,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.515,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1.44,363.52,0.0 -gfx1250,256,4,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5232,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2.88,366.3,0.0 -gfx1250,256,8,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5396,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,5.77,371.85,0.0 gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5724,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,11.53,382.93,0.0 gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.638,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,23.02,405.05,0.0 gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.7691,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,45.92,449.11,0.0 @@ -648,8 +588,4 @@ gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,302,0,48.6054,flydsl_bpresh gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,302,0,50.7041,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,677.65,1664.76,0.0 gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,302,0,54.9014,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1251.69,2769.39,0.0 gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,302,0,88.9453,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1545.21,3230.19,0.0 -gfx1250,256,6144,32768,512,torch.float8_e4m3fn,flydsl,302,0,119.3873,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1726.8,3539.54,0.0 gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,302,0,150.5706,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1825.57,3704.86,0.0 -gfx1250,256,10240,32768,512,torch.float8_e4m3fn,flydsl,302,0,187.2761,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1834.71,3701.0,0.0 -gfx1250,256,20480,32768,512,torch.float8_e4m3fn,flydsl,302,0,393.5255,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1746.25,3479.93,0.0 -gfx1250,256,65536,32768,512,torch.float8_e4m3fn,flydsl,21,0,846.9039,flydsl_bpreshuffle_wmma_t128x512x128_mw2_nw4_nb3_sk1_cm1_cn1,2596.54,5130.81,0.0 From 64e6ef3790d502af3216dedbfcde435326bd3b8b Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 14:38:07 +0000 Subject: [PATCH 17/19] gfx1250 gemm a8w8 untuned shape fallback to flydsl --- aiter/ops/gemm_op_a8w8.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 9525a64df0..2d43f58f75 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -662,6 +662,26 @@ def gemm_a8w8_bpreshuffle( return gemm_a8w8_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, Y, splitK) elif libtype == "flydsl" and is_flydsl_available(): return gemm_a8w8_bpreshuffle_flydsl(XQ, WQ, x_scale, w_scale, Y, config) + + if get_gfx() == "gfx1250" and is_flydsl_available(): + from ..ops.flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_wmma_common import ( + kernel_fits_shape, + kernels_list, + ) + + fits = [ki for ki in kernels_list.values() if kernel_fits_shape(ki, m, n, k)] + if fits: + want_tm = min(256, max(16, 1 << (m - 1).bit_length())) + ki = min( + fits, key=lambda x: (abs(x.tile_m - want_tm), -x.tile_n, -x.tile_k) + ) + logger.warning( + f"[gfx1250] gemm_a8w8_bpreshuffle untuned M={m}, N={n}, K={k}; " + f"falling back to flydsl kernel '{ki.name}'." + ) + return gemm_a8w8_bpreshuffle_flydsl( + XQ, WQ, x_scale, w_scale, Y, {"kernelName": ki.name} + ) try: return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y, 0) except RuntimeError as e: From 03f2799a1b7605b3447b99aac009d3db999a5575 Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 15:22:02 +0000 Subject: [PATCH 18/19] temp commit shapes for auto --- aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv | 40 --------- .../configs/a8w8_bpreshuffle_untuned_gemm.csv | 81 +++++++++++-------- aiter/ops/gemm_op_a8w8.py | 2 + 3 files changed, 51 insertions(+), 72 deletions(-) diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index 7ba0373836..9b9f652a22 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -549,43 +549,3 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0 gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0 gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 -gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,25,0,30.8522,flydsl_bpreshuffle_wmma_t32x32x256_mw2_nw2_nb3_sk1_cm1_cn1,15.7,496.6,0.011 -gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,286,0,28.9229,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,33.5,536.02,0.0105 -gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,269,0,31.7564,flydsl_bpreshuffle_wmma_t32x32x256_mw1_nw2_nb4_sk1_cm1_cn1,61.02,499.68,0.0098 -gfx1250,256,128,2112,7168,torch.float8_e4m3fn,flydsl,9,0,8.326,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,465.47,1993.39,0.0096 -gfx1250,256,256,2112,7168,torch.float8_e4m3fn,flydsl,19,0,25.7514,flydsl_bpreshuffle_wmma_t128x192x128_mw2_nw2_nb4_sk1_cm1_cn1,301.0,701.13,0.0098 -gfx1250,256,512,2112,7168,torch.float8_e4m3fn,flydsl,398,0,40.0855,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,386.73,523.17,0.0098 -gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096 -gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,398,0,126.0903,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,491.78,305.1,0.0098 -gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,398,0,240.7634,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,515.1,256.69,0.0098 -gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098 -gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.8317,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,118.06,3704.86,0.0237 -gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.9644,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,235.14,3704.86,0.0237 -gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.2297,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,466.41,3704.86,0.0237 -gfx1250,256,128,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.7604,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,917.72,3704.86,0.0237 -gfx1250,256,256,7168,16384,torch.float8_e4m3fn,flydsl,500,0,33.8217,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,1777.84,3704.86,0.0237 -gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,500,0,51.8551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2319.14,2568.1,0.0237 -gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2612.66,1617.42,0.0237 -gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237 -gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236 -gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237 -gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.8569,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,31.09,992.35,0.0 -gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,302,0,39.2634,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,61.53,1002.73,0.0 -gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,302,0,40.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,120.57,1022.86,0.0 -gfx1250,256,128,24576,1536,torch.float8_e4m3fn,flydsl,302,0,41.7027,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,231.73,1060.76,0.0 -gfx1250,256,256,24576,1536,torch.float8_e4m3fn,flydsl,302,0,44.9552,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,429.92,1128.34,0.0 -gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,302,0,51.4601,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,751.16,1237.87,0.0 -gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0 -gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0 -gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0 -gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0 -gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5724,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,11.53,382.93,0.0 -gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.638,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,23.02,405.05,0.0 -gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.7691,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,45.92,449.11,0.0 -gfx1250,256,128,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.0315,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,91.32,536.48,0.0 -gfx1250,256,256,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.5561,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,180.63,708.33,0.0 -gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,302,0,48.6054,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,353.46,1040.91,0.0 -gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,302,0,50.7041,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,677.65,1664.76,0.0 -gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,302,0,54.9014,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1251.69,2769.39,0.0 -gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,302,0,88.9453,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1545.21,3230.19,0.0 -gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,302,0,150.5706,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1825.57,3704.86,0.0 diff --git a/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv index 119932fa65..c6521c6762 100644 --- a/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv @@ -1,33 +1,50 @@ M,N,K,q_dtype_w -64,1536,5120,torch.int8 -128,1536,5120,torch.int8 -256,1536,5120,torch.int8 -512,1536,5120,torch.int8 -1024,1536,5120,torch.int8 -1664,1536,5120,torch.int8 -4096,1536,5120,torch.int8 -8192,1536,5120,torch.int8 -10240,1536,5120,torch.int8 -12288,1536,5120,torch.int8 -16384,1536,5120,torch.int8 -20480,1536,5120,torch.int8 -24576,1536,5120,torch.int8 -30720,1536,5120,torch.int8 -32768,1536,5120,torch.int8 -40960,1536,5120,torch.int8 -64,5120,1280,torch.int8 -128,5120,1280,torch.int8 -256,5120,1280,torch.int8 -512,5120,1280,torch.int8 -1024,5120,1280,torch.int8 -1664,5120,1280,torch.int8 -4096,5120,1280,torch.int8 -8192,5120,1280,torch.int8 -10240,5120,1280,torch.int8 -12288,5120,1280,torch.int8 -16384,5120,1280,torch.int8 -20480,5120,1280,torch.int8 -24576,5120,1280,torch.int8 -30720,5120,1280,torch.int8 -32768,5120,1280,torch.int8 -40960,5120,1280,torch.int8 +16,2112,7168,torch.float8_e4m3fn +32,2112,7168,torch.float8_e4m3fn +64,2112,7168,torch.float8_e4m3fn +128,2112,7168,torch.float8_e4m3fn +256,2112,7168,torch.float8_e4m3fn +512,2112,7168,torch.float8_e4m3fn +1024,2112,7168,torch.float8_e4m3fn +2048,2112,7168,torch.float8_e4m3fn +4096,2112,7168,torch.float8_e4m3fn +8192,2112,7168,torch.float8_e4m3fn +16384,2112,7168,torch.float8_e4m3fn +32768,2112,7168,torch.float8_e4m3fn +16,24576,1536,torch.float8_e4m3fn +32,24576,1536,torch.float8_e4m3fn +64,24576,1536,torch.float8_e4m3fn +128,24576,1536,torch.float8_e4m3fn +256,24576,1536,torch.float8_e4m3fn +512,24576,1536,torch.float8_e4m3fn +1024,24576,1536,torch.float8_e4m3fn +2048,24576,1536,torch.float8_e4m3fn +4096,24576,1536,torch.float8_e4m3fn +8192,24576,1536,torch.float8_e4m3fn +16384,24576,1536,torch.float8_e4m3fn +32768,24576,1536,torch.float8_e4m3fn +16,32768,512,torch.float8_e4m3fn +32,32768,512,torch.float8_e4m3fn +64,32768,512,torch.float8_e4m3fn +128,32768,512,torch.float8_e4m3fn +256,32768,512,torch.float8_e4m3fn +512,32768,512,torch.float8_e4m3fn +1024,32768,512,torch.float8_e4m3fn +2048,32768,512,torch.float8_e4m3fn +4096,32768,512,torch.float8_e4m3fn +8192,32768,512,torch.float8_e4m3fn +16384,32768,512,torch.float8_e4m3fn +32768,32768,512,torch.float8_e4m3fn +16,7168,16384,torch.float8_e4m3fn +32,7168,16384,torch.float8_e4m3fn +64,7168,16384,torch.float8_e4m3fn +128,7168,16384,torch.float8_e4m3fn +256,7168,16384,torch.float8_e4m3fn +512,7168,16384,torch.float8_e4m3fn +1024,7168,16384,torch.float8_e4m3fn +2048,7168,16384,torch.float8_e4m3fn +4096,7168,16384,torch.float8_e4m3fn +8192,7168,16384,torch.float8_e4m3fn +16384,7168,16384,torch.float8_e4m3fn +32768,7168,16384,torch.float8_e4m3fn + diff --git a/aiter/ops/gemm_op_a8w8.py b/aiter/ops/gemm_op_a8w8.py index 2d43f58f75..8192f715f1 100644 --- a/aiter/ops/gemm_op_a8w8.py +++ b/aiter/ops/gemm_op_a8w8.py @@ -474,6 +474,8 @@ def get_GEMM_config_with_quant_type( msg = f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned, in {tuned_file}!" if "libtype" in config: msg += f" libtype is {config['libtype']}!" + if "kernelName" in config: + msg += f" kernelName is {config['kernelName']} (kernelId {config.get('kernelId')})!" logger.info(msg) break if config is None: From c5da72d0b5f5aca06cb9260109d54a7491ccd420 Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 8 Jun 2026 15:25:51 +0000 Subject: [PATCH 19/19] temp disable splitk autotune --- .../gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py index ea2f7e5202..940453f2c7 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py @@ -48,7 +48,7 @@ (256, 256, 128, 2, 2, 4), (256, 256, 128, 4, 4, 3), (256, 512, 128, 2, 4, 3), ) # fmt: on -_SPLIT_K = (1, 2, 4) +_SPLIT_K = (1,) @dataclass