diff --git a/aiter/ops/pa_sparse_prefill_opus.py b/aiter/ops/pa_sparse_prefill_opus.py index 0af2627776..996e5d7ea1 100644 --- a/aiter/ops/pa_sparse_prefill_opus.py +++ b/aiter/ops/pa_sparse_prefill_opus.py @@ -142,7 +142,202 @@ def pa_sparse_prefill_opus( return out +# --------------------------------------------------------------------------- +# FP8 (DeepSeek-V4 / asm-v4 layout) variant. +# head dim 512 = NOPE(448, fp8 e4m3, per-64-tile fp32 scale) + ROPE(64, bf16), +# for both Q and KV. First implementation dequantizes to bf16 scratch (via a +# standalone device kernel) then runs the bf16 attention kernel. +# --------------------------------------------------------------------------- + +_FP8_NOPE = 448 +_FP8_ROPE = 64 +_FP8_FULL = 512 +_FP8_NUM_TILES = 7 + + +@compile_ops("module_pa_sparse_prefill_opus", develop=True) +def pa_sparse_prefill_opus_fp8_fwd( + q_nope: torch.Tensor, + q_rope: torch.Tensor, + q_scale: torch.Tensor, + unified_kv_nope: torch.Tensor, + unified_kv_rope: torch.Tensor, + unified_kv_scale: torch.Tensor, + kv_nope: torch.Tensor, + kv_rope: torch.Tensor, + kv_scale: torch.Tensor, + kv_indices_prefix: torch.Tensor, + kv_indptr_prefix: torch.Tensor, + kv_indices_extend: torch.Tensor, + kv_indptr_extend: torch.Tensor, + attn_sink: torch.Tensor, + q_bf16: torch.Tensor, + unified_kv_bf16: torch.Tensor, + kv_bf16: torch.Tensor, + out: torch.Tensor, + softmax_scale: float, +) -> None: ... + + +def pa_sparse_prefill_opus_fp8( + q_nope: torch.Tensor, + q_rope: torch.Tensor, + q_scale: torch.Tensor, + unified_kv_nope: torch.Tensor, + unified_kv_rope: torch.Tensor, + unified_kv_scale: torch.Tensor, + kv_nope: torch.Tensor, + kv_rope: torch.Tensor, + kv_scale: torch.Tensor, + kv_indices_prefix: torch.Tensor, + kv_indptr_prefix: torch.Tensor, + kv_indices_extend: torch.Tensor, + kv_indptr_extend: torch.Tensor, + attn_sink: torch.Tensor, + softmax_scale: float, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """FP8 (v4-layout) sparse prefill attention. + + Q and KV are stored as NOPE(448, fp8 e4m3, per-64-tile fp32 scale) + + ROPE(64, bf16). Output is bf16 ``[N, H, 512]``. The bf16 dequant scratch + is allocated here and freed when it goes out of scope. + + Args mirror :func:`pa_sparse_prefill_opus` but with each q/unified_kv/kv + split into ``*_nope`` (fp8 ``[..., 448]``), ``*_rope`` (bf16 ``[..., 64]``) + and ``*_scale`` (fp32 ``[..., 7]``). + """ + gfx = get_gfx_runtime() + if gfx != "gfx950": + raise RuntimeError(f"pa_sparse_prefill_opus_fp8 requires gfx950, got {gfx}") + + if q_nope.dtype != torch.float8_e4m3fn: + raise RuntimeError(f"q_nope must be fp8 e4m3fn, got {q_nope.dtype}") + if q_nope.size(-1) != _FP8_NOPE or q_rope.size(-1) != _FP8_ROPE: + raise RuntimeError( + f"expected nope last dim {_FP8_NOPE} and rope last dim {_FP8_ROPE}, " + f"got {q_nope.size(-1)} / {q_rope.size(-1)}" + ) + + n, h = q_nope.size(0), q_nope.size(1) + total_pages = unified_kv_nope.size(0) + total_tokens = kv_nope.size(0) + dev = q_nope.device + + q_bf16 = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=dev) + unified_kv_bf16 = torch.empty( + (total_pages, _FP8_FULL), dtype=torch.bfloat16, device=dev + ) + kv_bf16 = torch.empty((total_tokens, _FP8_FULL), dtype=torch.bfloat16, device=dev) + if out is None: + out = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=dev) + + pa_sparse_prefill_opus_fp8_fwd( + q_nope, + q_rope, + q_scale, + unified_kv_nope, + unified_kv_rope, + unified_kv_scale, + kv_nope, + kv_rope, + kv_scale, + kv_indices_prefix, + kv_indptr_prefix, + kv_indices_extend, + kv_indptr_extend, + attn_sink, + q_bf16, + unified_kv_bf16, + kv_bf16, + out, + float(softmax_scale), + ) + return out + + +@compile_ops("module_pa_sparse_prefill_opus", develop=True) +def pa_sparse_prefill_opus_fp8_fused_fwd( + q_nope: torch.Tensor, + q_rope: torch.Tensor, + q_scale: torch.Tensor, + unified_kv_nope: torch.Tensor, + unified_kv_rope: torch.Tensor, + unified_kv_scale: torch.Tensor, + kv_nope: torch.Tensor, + kv_rope: torch.Tensor, + kv_scale: torch.Tensor, + kv_indices_prefix: torch.Tensor, + kv_indptr_prefix: torch.Tensor, + kv_indices_extend: torch.Tensor, + kv_indptr_extend: torch.Tensor, + attn_sink: torch.Tensor, + out: torch.Tensor, + softmax_scale: float, +) -> None: ... + + +def pa_sparse_prefill_opus_fp8_fused( + q_nope: torch.Tensor, + q_rope: torch.Tensor, + q_scale: torch.Tensor, + unified_kv_nope: torch.Tensor, + unified_kv_rope: torch.Tensor, + unified_kv_scale: torch.Tensor, + kv_nope: torch.Tensor, + kv_rope: torch.Tensor, + kv_scale: torch.Tensor, + kv_indices_prefix: torch.Tensor, + kv_indptr_prefix: torch.Tensor, + kv_indices_extend: torch.Tensor, + kv_indptr_extend: torch.Tensor, + attn_sink: torch.Tensor, + softmax_scale: float, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fused fp8 (v4-layout) sparse prefill attention (no bf16 KV scratch). + + QK-nope runs in fp8 MFMA with software per-64-tile scale; QK-rope in bf16; + PV in bf16 with on-chip V dequant. H must be a multiple of 16. + """ + gfx = get_gfx_runtime() + if gfx != "gfx950": + raise RuntimeError( + f"pa_sparse_prefill_opus_fp8_fused requires gfx950, got {gfx}" + ) + if q_nope.dtype != torch.float8_e4m3fn: + raise RuntimeError(f"q_nope must be fp8 e4m3fn, got {q_nope.dtype}") + n, h = q_nope.size(0), q_nope.size(1) + if h % 16 != 0: + raise RuntimeError(f"H must be a multiple of 16, got {h}") + if out is None: + out = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=q_nope.device) + pa_sparse_prefill_opus_fp8_fused_fwd( + q_nope, + q_rope, + q_scale, + unified_kv_nope, + unified_kv_rope, + unified_kv_scale, + kv_nope, + kv_rope, + kv_scale, + kv_indices_prefix, + kv_indptr_prefix, + kv_indices_extend, + kv_indptr_extend, + attn_sink, + out, + float(softmax_scale), + ) + return out + + __all__ = [ "pa_sparse_prefill_opus_fwd", "pa_sparse_prefill_opus", + "pa_sparse_prefill_opus_fp8_fwd", + "pa_sparse_prefill_opus_fp8", + "pa_sparse_prefill_opus_fp8_fused_fwd", + "pa_sparse_prefill_opus_fp8_fused", ] diff --git a/csrc/include/pa_sparse_prefill_opus.h b/csrc/include/pa_sparse_prefill_opus.h index 0805902ed5..ee99c7b41a 100644 --- a/csrc/include/pa_sparse_prefill_opus.h +++ b/csrc/include/pa_sparse_prefill_opus.h @@ -36,6 +36,74 @@ void pa_sparse_prefill_opus_fwd(aiter_tensor_t& q, aiter_tensor_t& out, float softmax_scale); +// FP8 (DeepSeek-V4 / asm-v4 layout) variant of the prefill attention. +// +// Each head-dim row of D=512 is stored mixed-precision: +// NOPE : first 448 dims, FP8 (e4m3), with a per-64-element-tile fp32 scale +// (7 tiles -> 7 scales; values are e8m0-rounded powers of two). +// ROPE : last 64 dims, BF16 (never quantized). +// Applies to BOTH Q and KV (unified_kv + kv). +// +// Tensor expectations (row-major, last dim contiguous): +// q_nope : [N, H, 448] fp8 +// q_rope : [N, H, 64] bf16 +// q_scale : [N, H, 7] fp32 +// unified_kv_nope : [total_pages, 448] fp8 +// unified_kv_rope : [total_pages, 64] bf16 +// unified_kv_scale : [total_pages, 7] fp32 +// kv_nope : [total_tokens, 448] fp8 +// kv_rope : [total_tokens, 64] bf16 +// kv_scale : [total_tokens, 7] fp32 +// kv_indices/indptr : int32 (prefix + extend), as in the bf16 entry +// attn_sink : [H] fp32 +// q_bf16 / unified_kv_bf16 / kv_bf16 : caller-allocated bf16 dequant scratch +// ([N,H,512] / [total_pages,512] / [total_tokens,512]) +// out : [N, H, 512] bf16 (caller-allocated) +// +// This first implementation dequantizes q/unified_kv/kv into the bf16 scratch +// via a standalone device kernel, then runs the existing bf16 attention kernel +// on the scratch (a fused FP8 attention kernel is a follow-up). +void pa_sparse_prefill_opus_fp8_fwd(aiter_tensor_t& q_nope, + aiter_tensor_t& q_rope, + aiter_tensor_t& q_scale, + aiter_tensor_t& unified_kv_nope, + aiter_tensor_t& unified_kv_rope, + aiter_tensor_t& unified_kv_scale, + aiter_tensor_t& kv_nope, + aiter_tensor_t& kv_rope, + aiter_tensor_t& kv_scale, + aiter_tensor_t& kv_indices_prefix, + aiter_tensor_t& kv_indptr_prefix, + aiter_tensor_t& kv_indices_extend, + aiter_tensor_t& kv_indptr_extend, + aiter_tensor_t& attn_sink, + aiter_tensor_t& q_bf16, + aiter_tensor_t& unified_kv_bf16, + aiter_tensor_t& kv_bf16, + aiter_tensor_t& out, + float softmax_scale); + +// FUSED fp8 (v4-layout) prefill: reads fp8 nope + bf16 rope + fp32 scale directly, +// does QK-nope in fp8 MFMA (software per-64-tile scale) + QK-rope in bf16, and +// bf16 PV with on-chip V dequant. No bf16 KV scratch. H must be a multiple of 16. +// out is bf16 [N, H, 512]. +void pa_sparse_prefill_opus_fp8_fused_fwd(aiter_tensor_t& q_nope, + aiter_tensor_t& q_rope, + aiter_tensor_t& q_scale, + aiter_tensor_t& unified_kv_nope, + aiter_tensor_t& unified_kv_rope, + aiter_tensor_t& unified_kv_scale, + aiter_tensor_t& kv_nope, + aiter_tensor_t& kv_rope, + aiter_tensor_t& kv_scale, + aiter_tensor_t& kv_indices_prefix, + aiter_tensor_t& kv_indptr_prefix, + aiter_tensor_t& kv_indices_extend, + aiter_tensor_t& kv_indptr_extend, + aiter_tensor_t& attn_sink, + aiter_tensor_t& out, + float softmax_scale); + #ifdef PA_SPARSE_PREFILL_OPUS_IMPL // ============================================================================ // Implementation section - only compiled in the .cu translation unit @@ -67,6 +135,59 @@ struct pa_sparse_prefill_kargs float softmax_scale; }; +// Kernel arguments for the FUSED fp8 (v4-layout) prefill kernel. +// Q/KV split into nope (fp8 [.,448]) + rope (bf16 [.,64]) + scale (fp32 [.,7]). +// Out is bf16 [N,H,512]. Tensors are row-major contiguous (strides derived from H). +struct pa_sparse_prefill_fp8_kargs +{ + const void* __restrict__ q_nope; // [N,H,448] fp8 + const void* __restrict__ q_rope; // [N,H,64] bf16 + const float* __restrict__ q_scale; // [N,H,7] fp32 + const void* __restrict__ ukv_nope; // [total_pages,448] fp8 + const void* __restrict__ ukv_rope; // [total_pages,64] bf16 + const float* __restrict__ ukv_scale; // [total_pages,7] + const void* __restrict__ kv_nope; // [total_tokens,448] fp8 + const void* __restrict__ kv_rope; // [total_tokens,64] bf16 + const float* __restrict__ kv_scale; // [total_tokens,7] + const void* __restrict__ attn_sink; // [H] fp32 + void* __restrict__ out; // [N,H,512] bf16 + const int* __restrict__ kv_indptr_prefix; + const int* __restrict__ kv_indices_prefix; + const int* __restrict__ kv_indptr_extend; + const int* __restrict__ kv_indices_extend; + int N; + int H; + int total_pages; + int total_tokens; + float softmax_scale; +}; + +// Compile-time tile config for the fused fp8 (v4-layout) prefill kernel. +// Head dim DFULL = DNOPE (fp8, per-KTSZ-tile e8m0 scale) + DROPE (bf16). +// One warp handles a QTILE-head tile; up to MAX_WARPS independent head-tiles +// per block. KVTILE must be a multiple of 16; DNOPE a multiple of KTSZ. +template +struct pa_prefill_fp8_traits +{ + static constexpr int QTILE = Q_TILE_; + static constexpr int KVTILE = KV_TILE_; + static constexpr int DNOPE = D_NOPE_; + static constexpr int DROPE = D_ROPE_; + static constexpr int DFULL = D_NOPE_ + D_ROPE_; + static constexpr int KTSZ = KV_SCALE_TILE_; + static constexpr int NTILE = D_NOPE_ / KV_SCALE_TILE_; + static constexpr int NSUB = KV_TILE_ / 16; + static constexpr int MAX_WARPS = MAX_WARPS_; + static constexpr int WARP_SIZE = 64; + static_assert(D_NOPE_ % KV_SCALE_TILE_ == 0, "DNOPE must be a multiple of KTSZ"); + static_assert(KV_TILE_ % 16 == 0, "KVTILE must be a multiple of 16"); +}; + // Compile-time tile/MFMA configuration for the 16mx8_32nx1 variant (T_M=NUM_WARPS, // T_N=1). Used when H > 32. KV_TILE=32, NUM_WARPS=8, BLOCK_SIZE=512. template dequant happens at the smem->reg +// read. All smem geometry derives from sizeof(D_KV); the MFMA tile counts are +// dtype-independent. VEC_KV=8 reads 8 fp8 (8-byte load) and dequants to 8 bf16. +template +struct pa_prefill_16mx8_fp8_traits +{ + static constexpr int Q_TILE_SIZE = Q_TILE_SIZE_; + static constexpr int KV_TILE_SIZE = KV_TILE_SIZE_; + static constexpr int D_TILE_SIZE = D_TILE_SIZE_; + static constexpr int NUM_WARPS = NUM_WARPS_; + + static constexpr int WARP_SIZE = 64; + static constexpr int BLOCK_SIZE = NUM_WARPS * WARP_SIZE; + + using D_KV = D_KV_; // KV storage in LDS (fp8) + using D_ATTN = D_ATTN_; // Q / MFMA operands / O (bf16) + using D_ACC = float; + + static constexpr int T_M = NUM_WARPS; + static constexpr int T_N = 1; + static constexpr int T_K = 1; + + static constexpr int W_M = 16; + static constexpr int W_N = 16; + static constexpr int W_K = 32; + + static constexpr int SLICE_D = 32; + static constexpr int NUM_D_SLICES = D_TILE_SIZE / SLICE_D; + static_assert(D_TILE_SIZE % SLICE_D == 0); + + static constexpr int GEMM0_E_M = Q_TILE_SIZE / W_M; + static constexpr int GEMM0_E_N = KV_TILE_SIZE / W_N; + static constexpr int GEMM0_E_K = SLICE_D / W_K; + + static constexpr int GEMM1_E_M = Q_TILE_SIZE / W_M; + static constexpr int GEMM1_E_N = SLICE_D / W_N; + static constexpr int GEMM1_E_K = KV_TILE_SIZE / W_K; + + static constexpr int VEC_Q = 8; + static constexpr int VEC_KV = 8; + static constexpr int VEC_TR_V = 4; + static constexpr int VEC_O = 4; + + // smem geometry mirrors the bf16 (D_ATTN) ELEMENT layout exactly: the fp8 KV + // tile is a 1-byte-per-element copy of the bf16 layout, so async_load(fp8) + + // u_rk read deliver the SAME (lane,e)->(n,d) the bf16 MFMA b-operand expects + // (the MFMA is bf16; fp8 is dequanted at the read, AFTER placement). Geometry + // is therefore derived from sizeof(D_ATTN); only smem_size_bytes uses D_KV. + static constexpr int D_128B_SIZE = 128 / sizeof(D_ATTN); // 64 (elements/chunk) + static_assert(VEC_KV == 16 / sizeof(D_ATTN), "VEC_KV must match bf16 layout"); + static_assert(VEC_KV * (int)sizeof(D_KV) <= 16, "KV load must be <= 16B"); + static constexpr int smem_linear_wave = WARP_SIZE * 16 / sizeof(D_ATTN); // 512 + static constexpr int smem_n_per_wave = smem_linear_wave / D_128B_SIZE; // 8 + static constexpr int smem_n_rpt = KV_TILE_SIZE / smem_n_per_wave; // 4 + static constexpr int smem_d_rpt = D_TILE_SIZE / D_128B_SIZE; // 8 + static constexpr int smem_padding_32B = 32 / sizeof(D_ATTN); // 16 + static constexpr int smem_kv_tile_elems = + smem_n_rpt * smem_d_rpt * (smem_linear_wave + smem_padding_32B); + + static constexpr int kv_buffer_load_insts = + (KV_TILE_SIZE * D_TILE_SIZE) / (BLOCK_SIZE * VEC_KV); + static constexpr int k_ds_read_insts = + (GEMM0_E_N * GEMM0_E_K * W_N * W_K) / (WARP_SIZE * VEC_KV); + static constexpr int v_ds_read_insts = + (GEMM1_E_N * GEMM1_E_K * W_N * W_K) / (WARP_SIZE * VEC_TR_V); + + static constexpr size_t smem_size_bytes() + { + return 4 * smem_kv_tile_elems * sizeof(D_KV); + } +}; + // Compile-time tile/MFMA configuration for the 16mx1_16nx4 variant (T_M=1, // T_N=NUM_WARPS). Used when H <= 32. KV_TILE=64, NUM_WARPS=4, BLOCK_SIZE=256. template BF16 dequant kernel (v4 layout). +// Per row: out[0:448] = fp8(nope[0:448]) * scale[j/64]; out[448:512] = bf16(rope[0:64]). +struct pa_v4_dequant_kargs +{ + const void* __restrict__ nope_ptr; // fp8 [rows, 448] + const void* __restrict__ rope_ptr; // bf16 [rows, 64] + const float* __restrict__ scale_ptr; // fp32 [rows, 7] + void* __restrict__ out_ptr; // bf16 [rows, 512] + int rows; + int stride_nope; // elems/row (= 448 when contiguous) + int stride_rope; // elems/row (= 64) + int stride_scale; // elems/row (= 7) + int stride_out; // elems/row (= 512) +}; + // Device kernel templates — declared here, defined in the device pass below. template __global__ void pa_prefill_16mx8_32nx1_kernel(pa_sparse_prefill_kargs kargs); template __global__ void pa_prefill_16mx1_16nx4_kernel(pa_sparse_prefill_kargs kargs); +__global__ void pa_v4_fp8_dequant_kernel(pa_v4_dequant_kargs kargs); +template +__global__ void pa_prefill_fp8_fused_kernel(pa_sparse_prefill_fp8_kargs kargs); // Pull in the device kernel template bodies only on the gfx950 device pass. #if !defined(__HIP_DEVICE_COMPILE__) || !defined(__gfx950__) @@ -225,6 +442,13 @@ template __global__ void pa_prefill_16mx1_16nx4_kernel(pa_sparse_prefill_kargs) { } +__global__ void pa_v4_fp8_dequant_kernel(pa_v4_dequant_kargs) +{ +} +template +__global__ void pa_prefill_fp8_fused_kernel(pa_sparse_prefill_fp8_kargs) +{ +} #else // ============================================================================= // Device-side kernel implementation (gfx950 OPUS, D=512). @@ -232,9 +456,283 @@ __global__ void pa_prefill_16mx1_16nx4_kernel(pa_sparse_prefill_kargs) // ============================================================================= #include #include +#include using opus::operator""_I; +// Standalone FP8(e4m3)+e8m0 -> BF16 dequant for the v4 mixed-precision layout. +// One block per row; threads stride over the 512 output elements. +__global__ void pa_v4_fp8_dequant_kernel(pa_v4_dequant_kargs kargs) +{ + constexpr int D_NOPE = 448; + constexpr int D_ROPE = 64; + constexpr int D_FULL = D_NOPE + D_ROPE; // 512 + constexpr int TILE = 64; + + const int row = blockIdx.x; + if (row >= kargs.rows) return; + + const uint8_t* nope = reinterpret_cast(kargs.nope_ptr) + + static_cast(row) * kargs.stride_nope; + const bf16_t* rope = reinterpret_cast(kargs.rope_ptr) + + static_cast(row) * kargs.stride_rope; + const float* scale = kargs.scale_ptr + static_cast(row) * kargs.stride_scale; + bf16_t* out = reinterpret_cast(kargs.out_ptr) + + static_cast(row) * kargs.stride_out; + + for (int j = threadIdx.x; j < D_FULL; j += blockDim.x) + { + float v; + if (j < D_NOPE) + { + // gfx950 hardware fp8 = OCP e4m3fn; byte 0 of the packed int. + const float f = __builtin_amdgcn_cvt_f32_fp8(static_cast(nope[j]), 0); + v = f * scale[j / TILE]; + } + else + { + v = static_cast(rope[j - D_NOPE]); + } + out[j] = static_cast(v); + } +} + +// ============================================================================ +// Fused fp8 (v4-layout) prefill attention — single-warp, correctness-first. +// QK: nope = fp8 16x16x32 MFMA + software per-64-tile scale; rope = bf16 MFMA. +// Softmax: smem-mediated online (per head), per-head sink finalize. +// PV: dequant V (nope*scale ++ rope) -> bf16 smem, bf16 16x16x32 MFMA. +// One block = one query token x a 16-head tile, 64 lanes (1 wave). +// ============================================================================ +template +__global__ __launch_bounds__(256, 1) +void pa_prefill_fp8_fused_kernel(pa_sparse_prefill_fp8_kargs kargs) +{ + using namespace opus; // fp8x8_t / bf16x8_t / fp32x4_t etc. come from here + using T = opus::remove_cvref_t; + constexpr float LOG2_E = 1.44269504089f; + + const int qtok = blockIdx.x; + const int hblk = blockIdx.y; + const int H = kargs.H; + const int nwarp = blockDim.x / 64; // 1..MAX_WARPS independent head-tiles / block + const int warp_id = threadIdx.x / 64; + const int h0 = (hblk * nwarp + warp_id) * T::QTILE; // exact: nwarp | (H/16) + const int lane = threadIdx.x % 64; // 0..63 within warp + const int ml = lane % 16; // A/B load row within tile + const int kg = lane / 16; // 0..3 : K sub-block (8 wide) + const float temp = kargs.softmax_scale * LOG2_E; + + const fp8_t* q_nope = reinterpret_cast(kargs.q_nope); + const bf16_t* q_rope = reinterpret_cast(kargs.q_rope); + const float* q_scale = kargs.q_scale; + + // ---- smem (per-warp slices; sized for up to MAX_WARPS warps) ---- + __shared__ float sS_all[T::MAX_WARPS * T::QTILE * T::KVTILE]; + __shared__ bf16_t sP_all[T::MAX_WARPS * T::QTILE * T::KVTILE]; + __shared__ float s_m_all[T::MAX_WARPS * T::QTILE]; + __shared__ float s_l_all[T::MAX_WARPS * T::QTILE]; + __shared__ float s_corr_all[T::MAX_WARPS * T::QTILE]; + float* sS = sS_all + warp_id * (T::QTILE * T::KVTILE); + bf16_t* sP = sP_all + warp_id * (T::QTILE * T::KVTILE); + float* s_m = s_m_all + warp_id * T::QTILE; + float* s_l = s_l_all + warp_id * T::QTILE; + float* s_corr = s_corr_all + warp_id * T::QTILE; + + // ---- load Q for the QTILE-head tile (reused across all KV tiles) ---- + // A-load row = head ml; per nope tile t, sub-chunk kk; q_rope likewise. + fp8x8_t q8[T::NTILE][2]; + bf16x8_t qr[2]; + const int q_head = h0 + ml; + const fp8_t* qn_ptr = q_nope + (size_t)(qtok * H + q_head) * T::DNOPE; + const bf16_t* qr_ptr = q_rope + (size_t)(qtok * H + q_head) * T::DROPE; + for (int t = 0; t < T::NTILE; ++t) + for (int kk = 0; kk < 2; ++kk) { + int kb = t * T::KTSZ + kk * 32 + kg * 8; +#pragma unroll + for (int j = 0; j < 8; ++j) q8[t][kk][j] = qn_ptr[kb + j]; + } + for (int kk = 0; kk < 2; ++kk) { + int kb = kk * 32 + kg * 8; +#pragma unroll + for (int j = 0; j < 8; ++j) qr[kk][j] = qr_ptr[kb + j]; + } + // q_scale for the 4 OUTPUT heads this lane owns: (kg*4 + i), all NTILE tiles. + float qsc[4][T::NTILE]; + for (int i = 0; i < 4; ++i) + for (int t = 0; t < T::NTILE; ++t) + qsc[i][t] = q_scale[(size_t)(qtok * H + (h0 + kg * 4 + i)) * T::NTILE + t]; + + // ---- O accumulator (fp32) + online state ---- + // v_o[dsub*4 + i] = O[head=kg*4+i][d = dsub*16 + ml]; (DFULL/16) dsub * 4 / lane. + float v_o[T::DFULL / 16 * 4]; +#pragma unroll + for (int e = 0; e < T::DFULL / 16 * 4; ++e) v_o[e] = 0.f; + if (lane < T::QTILE) { s_m[lane] = -1e30f; s_l[lane] = 0.f; } + __builtin_amdgcn_s_barrier(); + + // ---- per-segment accumulate ---- + auto run_segment = [&](const fp8_t* kn_base, const bf16_t* kr_base, + const float* ksc_base, const int* indptr, const int* indices) { + const int beg = indptr[qtok]; + const int end = indptr[qtok + 1]; + const int vlen = end - beg; + const int ntiles = (vlen + T::KVTILE - 1) / T::KVTILE; + + for (int tile = 0; tile < ntiles; ++tile) { + // ---- QK: produce sS[QTILE, KVTILE] ---- + for (int s = 0; s < T::NSUB; ++s) { + int col = s * 16 + ml; // KV column within tile + int pos = tile * T::KVTILE + s * 16 + ml; // position in segment + int valid = (pos < vlen); + int row = valid ? indices[beg + pos] : 0; + const fp8_t* kn = kn_base + (size_t)row * T::DNOPE; + const bf16_t* kr = kr_base + (size_t)row * T::DROPE; + const float* ks = ksc_base + (size_t)row * T::NTILE; + + float acc[4] = {0.f, 0.f, 0.f, 0.f}; + // nope: fp8 MFMA per 64-tile + sw scale + for (int t = 0; t < T::NTILE; ++t) { + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; + for (int kk = 0; kk < 2; ++kk) { + int kb = t * T::KTSZ + kk * 32 + kg * 8; + fp8x8_t b_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) b_reg[j] = kn[kb + j]; + vc = mfma{}(q8[t][kk], b_reg, vc); + } + float ksc_t = ks[t]; +#pragma unroll + for (int i = 0; i < 4; ++i) acc[i] += vc[i] * qsc[i][t] * ksc_t; + } + // rope: bf16 MFMA (no scale) + { + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; + for (int kk = 0; kk < 2; ++kk) { + int kb = kk * 32 + kg * 8; + bf16x8_t b_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) b_reg[j] = kr[kb + j]; + vc = mfma{}(qr[kk], b_reg, vc); + } +#pragma unroll + for (int i = 0; i < 4; ++i) acc[i] += vc[i]; + } + // store to sS: row = output head kg*4+i, col +#pragma unroll + for (int i = 0; i < 4; ++i) + sS[(kg * 4 + i) * T::KVTILE + col] = valid ? acc[i] : -1e30f; + } + __builtin_amdgcn_s_barrier(); + + // ---- online softmax (lanes 0..QTILE-1 own one head each) ---- + if (lane < T::QTILE) { + int head = lane; + float mx = s_m[head]; +#pragma unroll + for (int c = 0; c < T::KVTILE; ++c) { + float v = sS[head * T::KVTILE + c]; + if (v > -1e29f) mx = max(mx, v * temp); + } + float corr = __builtin_amdgcn_exp2f(s_m[head] - mx); + float ltile = 0.f; +#pragma unroll + for (int c = 0; c < T::KVTILE; ++c) { + float v = sS[head * T::KVTILE + c]; + float p = (v > -1e29f) ? __builtin_amdgcn_exp2f(v * temp - mx) : 0.f; + sP[head * T::KVTILE + c] = (bf16_t)p; + ltile += p; + } + s_l[head] = s_l[head] * corr + ltile; + s_m[head] = mx; + s_corr[head] = corr; + } + __builtin_amdgcn_s_barrier(); + + // ---- rescale O by per-head correction ---- +#pragma unroll + for (int i = 0; i < 4; ++i) { + float corr = s_corr[kg * 4 + i]; + for (int dsub = 0; dsub < T::DFULL / 16; ++dsub) + v_o[dsub * 4 + i] *= corr; + } + + // ---- PV: O += P[QTILE,KVTILE] @ V[KVTILE,DFULL], V dequant INLINE ---- + // (no LDS V staging: each V[n][d] is consumed by exactly one lane.) + // KVTILE = KC * 32 contraction chunks; each lane's chunk c needs KV + // rows n = c*32 + kg*8 + j (j=0..7). + constexpr int KC = T::KVTILE / 32; + const unsigned char* kn_bytes = reinterpret_cast(kn_base); + int vrow[KC][8]; +#pragma unroll + for (int c = 0; c < KC; ++c) +#pragma unroll + for (int j = 0; j < 8; ++j) { + int pos = tile * T::KVTILE + c * 32 + kg * 8 + j; + vrow[c][j] = (pos < vlen) ? indices[beg + pos] : -1; + } + bf16x8_t p_reg[KC]; +#pragma unroll + for (int c = 0; c < KC; ++c) +#pragma unroll + for (int j = 0; j < 8; ++j) + p_reg[c][j] = sP[ml * T::KVTILE + c * 32 + kg * 8 + j]; + + for (int dsub = 0; dsub < T::DFULL / 16; ++dsub) { + int d = dsub * 16 + ml; + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; +#pragma unroll + for (int c = 0; c < KC; ++c) { + bf16x8_t v_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) { + int row = vrow[c][j]; + float val = 0.f; + if (row >= 0) { + if (d < T::DNOPE) { + float f = __builtin_amdgcn_cvt_f32_fp8( + (int)kn_bytes[(size_t)row * T::DNOPE + d], 0); + val = f * ksc_base[(size_t)row * T::NTILE + (d / T::KTSZ)]; + } else { + val = (float)kr_base[(size_t)row * T::DROPE + (d - T::DNOPE)]; + } + } + v_reg[j] = (bf16_t)val; + } + vc = mfma{}(p_reg[c], v_reg, vc); + } +#pragma unroll + for (int i = 0; i < 4; ++i) v_o[dsub * 4 + i] += vc[i]; + } + __builtin_amdgcn_s_barrier(); + } + }; + + run_segment(reinterpret_cast(kargs.ukv_nope), + reinterpret_cast(kargs.ukv_rope), + kargs.ukv_scale, kargs.kv_indptr_prefix, kargs.kv_indices_prefix); + run_segment(reinterpret_cast(kargs.kv_nope), + reinterpret_cast(kargs.kv_rope), + kargs.kv_scale, kargs.kv_indptr_extend, kargs.kv_indices_extend); + + // ---- sink finalize + normalize + store ---- + const float* sink = reinterpret_cast(kargs.attn_sink); + bf16_t* out = reinterpret_cast(kargs.out); +#pragma unroll + for (int i = 0; i < 4; ++i) { + int head = kg * 4 + i; + float sink_log2 = sink[h0 + head] * LOG2_E; + float m_final = max(s_m[head], sink_log2); + float alpha = __builtin_amdgcn_exp2f(s_m[head] - m_final); + float l_final = s_l[head] * alpha + __builtin_amdgcn_exp2f(sink_log2 - m_final); + float o_scale = (l_final > 0.f) ? (alpha / l_final) : 0.f; + for (int dsub = 0; dsub < T::DFULL / 16; ++dsub) { + float o = v_o[dsub * 4 + i] * o_scale; + out[((size_t)(qtok * H + (h0 + head)) * T::DFULL) + dsub * 16 + ml] = (bf16_t)o; + } + } +} + // ============================================================================= // Variant 16mx8_32nx1 (T_M=NUM_WARPS, T_N=1) — used when H > 32. // ============================================================================= diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8693a40f20..9b9424eec4 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1328,6 +1328,49 @@ namespace py = pybind11; py::arg("out"), \ py::arg("softmax_scale")); +#define PA_SPARSE_PREFILL_OPUS_FP8_PYBIND \ + m.def("pa_sparse_prefill_opus_fp8_fwd", \ + &pa_sparse_prefill_opus_fp8_fwd, \ + py::arg("q_nope"), \ + py::arg("q_rope"), \ + py::arg("q_scale"), \ + py::arg("unified_kv_nope"), \ + py::arg("unified_kv_rope"), \ + py::arg("unified_kv_scale"), \ + py::arg("kv_nope"), \ + py::arg("kv_rope"), \ + py::arg("kv_scale"), \ + py::arg("kv_indices_prefix"), \ + py::arg("kv_indptr_prefix"), \ + py::arg("kv_indices_extend"), \ + py::arg("kv_indptr_extend"), \ + py::arg("attn_sink"), \ + py::arg("q_bf16"), \ + py::arg("unified_kv_bf16"), \ + py::arg("kv_bf16"), \ + py::arg("out"), \ + py::arg("softmax_scale")); + +#define PA_SPARSE_PREFILL_OPUS_FP8_FUSED_PYBIND \ + m.def("pa_sparse_prefill_opus_fp8_fused_fwd", \ + &pa_sparse_prefill_opus_fp8_fused_fwd, \ + py::arg("q_nope"), \ + py::arg("q_rope"), \ + py::arg("q_scale"), \ + py::arg("unified_kv_nope"), \ + py::arg("unified_kv_rope"), \ + py::arg("unified_kv_scale"), \ + py::arg("kv_nope"), \ + py::arg("kv_rope"), \ + py::arg("kv_scale"), \ + py::arg("kv_indices_prefix"), \ + py::arg("kv_indptr_prefix"), \ + py::arg("kv_indices_extend"), \ + py::arg("kv_indptr_extend"), \ + py::arg("attn_sink"), \ + py::arg("out"), \ + py::arg("softmax_scale")); + #define NORM_PYBIND \ m.def("layernorm2d_fwd", \ &layernorm2d, \ diff --git a/csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu b/csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu index 996d650170..09abacbd37 100644 --- a/csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu +++ b/csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu @@ -130,3 +130,173 @@ void pa_sparse_prefill_opus_fwd(aiter_tensor_t& q, #undef LAUNCH_PA_PREFILL } + +// ============================================================================= +// FP8 (DeepSeek-V4 / asm-v4 layout) entry: dequant q/unified_kv/kv into bf16 +// scratch via the standalone dequant kernel, then run the bf16 attention. +// ============================================================================= +void pa_sparse_prefill_opus_fp8_fwd(aiter_tensor_t& q_nope, + aiter_tensor_t& q_rope, + aiter_tensor_t& q_scale, + aiter_tensor_t& unified_kv_nope, + aiter_tensor_t& unified_kv_rope, + aiter_tensor_t& unified_kv_scale, + aiter_tensor_t& kv_nope, + aiter_tensor_t& kv_rope, + aiter_tensor_t& kv_scale, + aiter_tensor_t& kv_indices_prefix, + aiter_tensor_t& kv_indptr_prefix, + aiter_tensor_t& kv_indices_extend, + aiter_tensor_t& kv_indptr_extend, + aiter_tensor_t& attn_sink, + aiter_tensor_t& q_bf16, + aiter_tensor_t& unified_kv_bf16, + aiter_tensor_t& kv_bf16, + aiter_tensor_t& out, + float softmax_scale) +{ + constexpr int D_NOPE = 448; + constexpr int D_ROPE = 64; + constexpr int D_FULL = 512; + constexpr int NUM_TILES = 7; + + // ---- dtype validation ------------------------------------------------- + auto check_nope = [&](aiter_tensor_t& t, const char* nm) { + AITER_CHECK(t.dtype() == AITER_DTYPE_fp8, nm, " must be fp8 (e4m3)"); + AITER_CHECK(t.is_contiguous(), nm, " must be contiguous"); + AITER_CHECK(t.size(t.dim() - 1) == D_NOPE, nm, " last dim must be 448"); + }; + auto check_rope = [&](aiter_tensor_t& t, const char* nm) { + AITER_CHECK(t.dtype() == AITER_DTYPE_bf16, nm, " must be bf16"); + AITER_CHECK(t.is_contiguous(), nm, " must be contiguous"); + AITER_CHECK(t.size(t.dim() - 1) == D_ROPE, nm, " last dim must be 64"); + }; + auto check_scale = [&](aiter_tensor_t& t, const char* nm) { + AITER_CHECK(t.dtype() == AITER_DTYPE_fp32, nm, " must be fp32"); + AITER_CHECK(t.is_contiguous(), nm, " must be contiguous"); + AITER_CHECK(t.size(t.dim() - 1) == NUM_TILES, nm, " last dim must be 7"); + }; + auto check_bf16_scratch = [&](aiter_tensor_t& t, const char* nm) { + AITER_CHECK(t.dtype() == AITER_DTYPE_bf16, nm, " scratch must be bf16"); + AITER_CHECK(t.is_contiguous(), nm, " scratch must be contiguous"); + AITER_CHECK(t.size(t.dim() - 1) == D_FULL, nm, " scratch last dim must be 512"); + }; + + check_nope(q_nope, "q_nope"); check_rope(q_rope, "q_rope"); check_scale(q_scale, "q_scale"); + check_nope(unified_kv_nope, "unified_kv_nope"); check_rope(unified_kv_rope, "unified_kv_rope"); check_scale(unified_kv_scale, "unified_kv_scale"); + check_nope(kv_nope, "kv_nope"); check_rope(kv_rope, "kv_rope"); check_scale(kv_scale, "kv_scale"); + check_bf16_scratch(q_bf16, "q_bf16"); check_bf16_scratch(unified_kv_bf16, "unified_kv_bf16"); check_bf16_scratch(kv_bf16, "kv_bf16"); + + HipDeviceGuard guard(q_nope.device_id); + const hipStream_t stream = aiter::getCurrentHIPStream(); + + auto rows_of = [&](aiter_tensor_t& t) { + long n = 1; + for(int i = 0; i < t.dim() - 1; ++i) n *= t.size(i); + return static_cast(n); + }; + + auto launch_dequant = [&](aiter_tensor_t& nope, aiter_tensor_t& rope, + aiter_tensor_t& scale, aiter_tensor_t& dst, int rows) { + if(rows == 0) return; + pa_v4_dequant_kargs dk{}; + dk.nope_ptr = nope.data_ptr(); + dk.rope_ptr = rope.data_ptr(); + dk.scale_ptr = reinterpret_cast(scale.data_ptr()); + dk.out_ptr = dst.data_ptr(); + dk.rows = rows; + dk.stride_nope = D_NOPE; + dk.stride_rope = D_ROPE; + dk.stride_scale = NUM_TILES; + dk.stride_out = D_FULL; + dim3 grid(rows); + dim3 block(128); + pa_v4_fp8_dequant_kernel<<>>(dk); + HIP_CALL_LAUNCH(hipGetLastError()); + }; + + launch_dequant(q_nope, q_rope, q_scale, q_bf16, rows_of(q_nope)); + launch_dequant(unified_kv_nope, unified_kv_rope, unified_kv_scale, unified_kv_bf16, rows_of(unified_kv_nope)); + launch_dequant(kv_nope, kv_rope, kv_scale, kv_bf16, rows_of(kv_nope)); + + // Run the existing, proven bf16 attention on the dequantized scratch. + pa_sparse_prefill_opus_fwd(q_bf16, + unified_kv_bf16, + kv_indices_prefix, + kv_indptr_prefix, + kv_bf16, + kv_indices_extend, + kv_indptr_extend, + attn_sink, + out, + softmax_scale); +} + +// ============================================================================= +// FUSED fp8 entry: single-warp kernel reading fp8 KV directly (no bf16 scratch). +// ============================================================================= +void pa_sparse_prefill_opus_fp8_fused_fwd(aiter_tensor_t& q_nope, + aiter_tensor_t& q_rope, + aiter_tensor_t& q_scale, + aiter_tensor_t& unified_kv_nope, + aiter_tensor_t& unified_kv_rope, + aiter_tensor_t& unified_kv_scale, + aiter_tensor_t& kv_nope, + aiter_tensor_t& kv_rope, + aiter_tensor_t& kv_scale, + aiter_tensor_t& kv_indices_prefix, + aiter_tensor_t& kv_indptr_prefix, + aiter_tensor_t& kv_indices_extend, + aiter_tensor_t& kv_indptr_extend, + aiter_tensor_t& attn_sink, + aiter_tensor_t& out, + float softmax_scale) +{ + const int N = static_cast(q_nope.size(0)); + const int H = static_cast(q_nope.size(1)); + AITER_CHECK(q_nope.dtype() == AITER_DTYPE_fp8, "q_nope must be fp8"); + AITER_CHECK(q_rope.dtype() == AITER_DTYPE_bf16, "q_rope must be bf16"); + AITER_CHECK(out.dtype() == AITER_DTYPE_bf16, "out must be bf16"); + AITER_CHECK(H % 16 == 0, "H must be a multiple of 16 for the fused fp8 kernel, got H=", H); + AITER_CHECK(q_nope.size(2) == 448, "q_nope last dim must be 448"); + if (N == 0) return; + + pa_sparse_prefill_fp8_kargs kargs{}; + kargs.q_nope = q_nope.data_ptr(); + kargs.q_rope = q_rope.data_ptr(); + kargs.q_scale = reinterpret_cast(q_scale.data_ptr()); + kargs.ukv_nope = unified_kv_nope.data_ptr(); + kargs.ukv_rope = unified_kv_rope.data_ptr(); + kargs.ukv_scale = reinterpret_cast(unified_kv_scale.data_ptr()); + kargs.kv_nope = kv_nope.data_ptr(); + kargs.kv_rope = kv_rope.data_ptr(); + kargs.kv_scale = reinterpret_cast(kv_scale.data_ptr()); + kargs.attn_sink = attn_sink.data_ptr(); + kargs.out = out.data_ptr(); + kargs.kv_indptr_prefix = reinterpret_cast(kv_indptr_prefix.data_ptr()); + kargs.kv_indices_prefix = reinterpret_cast(kv_indices_prefix.data_ptr()); + kargs.kv_indptr_extend = reinterpret_cast(kv_indptr_extend.data_ptr()); + kargs.kv_indices_extend = reinterpret_cast(kv_indices_extend.data_ptr()); + kargs.N = N; + kargs.H = H; + kargs.total_pages = static_cast(unified_kv_nope.size(0)); + kargs.total_tokens = static_cast(kv_nope.size(0)); + kargs.softmax_scale = softmax_scale; + + HipDeviceGuard guard(q_nope.device_id); + const hipStream_t stream = aiter::getCurrentHIPStream(); + // v4 layout: 448 nope + 64 rope, e8m0 scale per 64-elem tile, KV tile 32, + // up to MAX_WARPS=4 independent head-tiles/block. + using Traits = pa_prefill_fp8_traits<16, 32, 448, 64, 64, 4>; + // Pack up to MAX_WARPS head-tiles per block (more resident waves to hide + // memory latency). nwarp must divide H/16 so every warp maps to a valid head. + const int htiles = H / Traits::QTILE; + int nwarp = 1; + for (int nw = Traits::MAX_WARPS; nw >= 1; --nw) { + if (htiles % nw == 0) { nwarp = nw; break; } + } + dim3 grid(N, htiles / nwarp, 1); + dim3 block(nwarp * Traits::WARP_SIZE); + pa_prefill_fp8_fused_kernel<<>>(kargs); + HIP_CALL_LAUNCH(hipGetLastError()); +} diff --git a/csrc/pybind/pa_sparse_prefill_opus_pybind.cu b/csrc/pybind/pa_sparse_prefill_opus_pybind.cu index 59726c4a38..eb7eed22a4 100644 --- a/csrc/pybind/pa_sparse_prefill_opus_pybind.cu +++ b/csrc/pybind/pa_sparse_prefill_opus_pybind.cu @@ -8,4 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { AITER_SET_STREAM_PYBIND PA_SPARSE_PREFILL_OPUS_PYBIND; + PA_SPARSE_PREFILL_OPUS_FP8_PYBIND; + PA_SPARSE_PREFILL_OPUS_FP8_FUSED_PYBIND; } diff --git a/op_tests/bench_pa_sparse_prefill_opus_fp8.py b/op_tests/bench_pa_sparse_prefill_opus_fp8.py new file mode 100644 index 0000000000..916b4a6cc3 --- /dev/null +++ b/op_tests/bench_pa_sparse_prefill_opus_fp8.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +"""Benchmark: bf16 baseline vs fp8 dequant-prepass vs fp8 fused, same problem.""" + +from __future__ import annotations +import math +import sys +import pandas as pd +import torch + +import aiter # noqa +from aiter.ops.pa_sparse_prefill_opus import ( + pa_sparse_prefill_opus, + pa_sparse_prefill_opus_fp8_fused, +) +from aiter.test_common import run_perftest + +sys.path.insert(0, "op_tests") +from test_pa_sparse_prefill_opus_fp8 import _make_fp8_inputs # noqa: E402 +from pa_sparse_prefill_opus_fp8_quant import dequantize_v4_fp8, D_FULL # noqa: E402 + + +def bench_point(n, h, pages, tokens, mode): + inp = _make_fp8_inputs(n, h, pages, tokens, mode=mode, seed=0) + ss = 1.0 / math.sqrt(D_FULL) + # bf16 baseline operates on dequanted inputs. + q = dequantize_v4_fp8(inp["q_nope"], inp["q_rope"], inp["q_scale"]) + ukv = dequantize_v4_fp8(inp["ukv_nope"], inp["ukv_rope"], inp["ukv_scale"]) + kv = dequantize_v4_fp8(inp["kv_nope"], inp["kv_rope"], inp["kv_scale"]) + + _, us_bf16 = run_perftest( + pa_sparse_prefill_opus, + q, + ukv, + inp["kv_indices_prefix"], + inp["kv_indptr_prefix"], + kv, + inp["kv_indices_extend"], + inp["kv_indptr_extend"], + inp["attn_sink"], + ss, + num_iters=50, + num_warmup=5, + ) + _, us_fused = run_perftest( + pa_sparse_prefill_opus_fp8_fused, + inp["q_nope"], + inp["q_rope"], + inp["q_scale"], + inp["ukv_nope"], + inp["ukv_rope"], + inp["ukv_scale"], + inp["kv_nope"], + inp["kv_rope"], + inp["kv_scale"], + inp["kv_indices_prefix"], + inp["kv_indptr_prefix"], + inp["kv_indices_extend"], + inp["kv_indptr_extend"], + inp["attn_sink"], + ss, + num_iters=50, + num_warmup=5, + ) + nnz = int(inp["kv_indices_prefix"].numel()) + int(inp["kv_indices_extend"].numel()) + return { + "N": n, + "H": h, + "mode": mode, + "nnz": nnz, + "bf16_us": round(us_bf16, 1), + "fused_us": round(us_fused, 1), + "fused/bf16": round(us_fused / us_bf16, 2), + } + + +if __name__ == "__main__": + arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] + if arch != "gfx950": + print(f"SKIP: needs gfx950, got {arch}") + sys.exit(0) + rows = [] + for n, h, pages, tokens in [ + (1024, 128, 4096, 1024), + (4096, 128, 16384, 4096), + (1024, 16, 4096, 1024), + ]: + for mode in ["dense", "sparse"]: + rows.append(bench_point(n, h, pages, tokens, mode)) + print(rows[-1]) + print() + print(pd.DataFrame(rows).to_string(index=False)) diff --git a/op_tests/opus/fp8_lds_probe.cc b/op_tests/opus/fp8_lds_probe.cc new file mode 100644 index 0000000000..d16d47e5f5 --- /dev/null +++ b/op_tests/opus/fp8_lds_probe.cc @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Approach-A foundation: validate that the 16mx8 traits/layouts instantiate with +// D_ATTN=fp8_t and that async_load(fp8)->fp8 LDS, read via u_rk, reproduces the +// SAME logical K as the bf16 path. Uses small-integer data (exact in fp8 & bf16) +// so this tests the LAYOUT (not quantization). PASS => fp8 LDS round-trip works, +// so Approach A (fp8 in LDS + dequant at read) is feasible by instantiation. +// +// Build (docker gfx950): +// hipcc -std=c++20 --offload-arch=gfx950 -O2 -I csrc/include op_tests/opus/fp8_lds_probe.cc -o /tmp/flp && /tmp/flp +#define PA_SPARSE_PREFILL_OPUS_IMPL +#include "pa_sparse_prefill_opus.h" +#include "opus/opus.hpp" +#ifndef __HIP_DEVICE_COMPILE__ +#include "opus/hip_minimal.hpp" +#include +#include +#include +#endif + +using Tbf = pa_prefill_16mx8_32nx1_traits<16, 32, 512, 8, bf16_t>; +using Tfp8 = pa_prefill_16mx8_fp8_traits<16, 32, 512, 8, unsigned char, bf16_t>; +constexpr int NROW = 32, D = 512; + +// read K slice0 (u_rk, no skv offset) for each lane, convert to float, write [64*ELEM] +__global__ void readK(const bf16_t* g_bf, const unsigned char* g_fp8, + const int* kvidx, float* out_bf, float* out_fp8) +{ +#if defined(__gfx950__) + using namespace opus; + using namespace pa_16mx8_32nx1; + int lane = thread_id_x() % 64; + int warp = __builtin_amdgcn_readfirstlane(thread_id_x() / 64); + + // ---- bf16 path ---- + { + using T = Tbf; + __shared__ char sm[T::smem_kv_tile_elems * sizeof(bf16_t)]; + auto s = make_smem(reinterpret_cast(sm)); + auto g = make_gmem(g_bf, NROW * D * (int)sizeof(bf16_t)); + auto u_g = make_layout_gkv(warp, lane); + auto u_s = make_layout_skv(warp); + auto u_kvi = make_layout_kv_indices(warp, lane); + auto g_kvi = make_gmem(kvidx, NROW * (int)sizeof(int)); + int kv_page = load(g_kvi, u_kvi, 0)[0]; + async_load(g, s.ptr, u_g + kv_page * D, u_s); + s_waitcnt_vmcnt(0_I); __builtin_amdgcn_s_barrier(); + auto u_rk = make_layout_rk(lane); + auto vk = load(s, u_rk); + constexpr int N = vector_traits::size(); + for (int e = 0; e < N; ++e) out_bf[thread_id_x() * N + e] = (float)vk[e]; + } + __builtin_amdgcn_s_barrier(); + // ---- fp8 path ---- + { + using T = Tfp8; + using D_KV = typename T::D_KV; + __shared__ char sm[T::smem_kv_tile_elems * sizeof(D_KV)]; + auto s = make_smem(reinterpret_cast(sm)); + auto g = make_gmem(g_fp8, NROW * D * (int)sizeof(D_KV)); + auto u_g = make_layout_gkv(warp, lane); + auto u_s = make_layout_skv(warp); + auto u_kvi = make_layout_kv_indices(warp, lane); + auto g_kvi = make_gmem(kvidx, NROW * (int)sizeof(int)); + int kv_page = load(g_kvi, u_kvi, 0)[0]; + async_load(g, s.ptr, u_g + kv_page * D, u_s); + s_waitcnt_vmcnt(0_I); __builtin_amdgcn_s_barrier(); + auto u_rk = make_layout_rk(lane); + auto vk = load(s, u_rk); + constexpr int N = vector_traits::size(); + for (int e = 0; e < N; ++e) { + float f = __builtin_amdgcn_cvt_f32_fp8((int)(unsigned char)vk[e], 0); + out_fp8[thread_id_x() * N + e] = f; + } + } +#endif +} + +#ifndef __HIP_DEVICE_COMPILE__ +static unsigned char i_to_e4m3(int v){ // exact for small ints in [-7..7]ish + if(v==0)return 0; int s=v<0?0x80:0; v=v<0?-v:v; + int e=0,x=v; while(x>1){x>>=1;e++;} int M=(int)(((float)v/(1<0.01f){mism++; if(d>maxd)maxd=d;} } + printf("u_rk K read: bf16-path vs fp8-path (exact-int data) mism=%d maxd=%.3f -> %s\n", + mism,maxd, mism==0?"PASS (fp8 LDS layout works)":"FAIL"); + return mism==0?0:1; +} +#endif diff --git a/op_tests/opus/qk_full_probe.cc b/op_tests/opus/qk_full_probe.cc new file mode 100644 index 0000000000..65cd40cf3c --- /dev/null +++ b/op_tests/opus/qk_full_probe.cc @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Validate the full QK first-GEMM in isolation: +// S[m,n] = NOPE(fp8 16x16x32 MFMA + per-64-tile sw scale) + ROPE(bf16 16x16x32 MFMA) +// vs a dequant CPU reference. This is the complete first GEMM the fused attention +// kernel will use ("mxfp8 nope, then bf16 mfma accumulate the rope part"). +// +// Build (docker, gfx950): +// hipcc -std=c++20 --offload-arch=gfx950 -O2 -I csrc/include op_tests/opus/qk_full_probe.cc -o /tmp/qkf && /tmp/qkf +#include "opus/opus.hpp" +#ifndef __HIP_DEVICE_COMPILE__ +#include "opus/hip_minimal.hpp" +#include +#include +#include +#endif + +using opus::fp8_t; +using opus::bf16_t; +using opus::fp32_t; +using fp8x8_t = opus::fp8x8_t; +using bf16x8_t = opus::bf16x8_t; +using fp32x4_t = opus::vector_t; + +constexpr int M = 16, N = 16, KD = 448, TILE = 64, NTILE = 7; +constexpr int RD = 64; // rope dim (bf16) + +__global__ void qk_full(const fp8_t* __restrict__ q8, // [M, KD] nope fp8 + const fp8_t* __restrict__ k8, // [N, KD] nope fp8 + const float* __restrict__ sq, // [M, NTILE] + const float* __restrict__ sk, // [N, NTILE] + const bf16_t* __restrict__ qr, // [M, RD] rope bf16 + const bf16_t* __restrict__ kr, // [N, RD] rope bf16 + float* __restrict__ S) // [M, N] +{ +#if defined(__gfx950__) + using namespace opus; + int lane = (int)__builtin_amdgcn_workitem_id_x(); + int m = lane % 16; + int kblk = lane / 16; // 0..3 + int n = lane % 16; + float s[4] = {0.f, 0.f, 0.f, 0.f}; + + // --- NOPE: fp8 MFMA + software per-64-tile scale --- + for (int t = 0; t < NTILE; ++t) { + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; + for (int kk = 0; kk < 2; ++kk) { + int kbase = t * TILE + kk * 32; + fp8x8_t a_reg, b_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) { + a_reg[j] = q8[m * KD + kbase + kblk * 8 + j]; + b_reg[j] = k8[n * KD + kbase + kblk * 8 + j]; + } + vc = mfma{}(a_reg, b_reg, vc); + } +#pragma unroll + for (int i = 0; i < 4; ++i) + s[i] += vc[i] * sq[(kblk * 4 + i) * NTILE + t] * sk[n * NTILE + t]; + } + + // --- ROPE: bf16 MFMA (no scale), accumulate into same S --- + { + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; + for (int kk = 0; kk < 2; ++kk) { // 2 x 32 = 64 rope + int kbase = kk * 32; + bf16x8_t a_reg, b_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) { + a_reg[j] = qr[m * RD + kbase + kblk * 8 + j]; + b_reg[j] = kr[n * RD + kbase + kblk * 8 + j]; + } + vc = mfma{}(a_reg, b_reg, vc); + } +#pragma unroll + for (int i = 0; i < 4; ++i) s[i] += vc[i]; + } + +#pragma unroll + for (int i = 0; i < 4; ++i) S[(kblk * 4 + i) * N + n] = s[i]; +#endif +} + +#ifndef __HIP_DEVICE_COMPILE__ +static float e4m3_decode(unsigned char b) { + int s = (b >> 7) & 1, e = (b >> 3) & 0xF, mant = b & 0x7; + float v; + if (e == 0) v = (float)mant / 8.0f * 0.015625f; + else v = (1.0f + (float)mant / 8.0f) * ldexpf(1.0f, e - 7); + return s ? -v : v; +} +static unsigned char rand_fp8() { + for (;;) { + unsigned char b = (unsigned char)(rand() & 0x3F); + int e = (b >> 3) & 0xF, mant = b & 0x7; + if (e == 15 && mant == 7) continue; + return b; + } +} +static unsigned short to_bf16(float f) { + unsigned int x; __builtin_memcpy(&x, &f, 4); + unsigned int round = ((x >> 16) & 1u) + 0x7FFFu; + return (unsigned short)(((x + round) >> 16) & 0xFFFFu); +} +static float from_bf16(unsigned short h) { + unsigned int x = ((unsigned int)h) << 16; float f; __builtin_memcpy(&f, &x, 4); return f; +} + +int main() { + srand(1234); + const int qn=M*KD, kn=N*KD, sqn=M*NTILE, skn=N*NTILE, qrn=M*RD, krn=N*RD; + unsigned char *hq=new unsigned char[qn], *hk=new unsigned char[kn]; + float *hsq=new float[sqn], *hsk=new float[skn]; + unsigned short *hqr=new unsigned short[qrn], *hkr=new unsigned short[krn]; + for (int i=0;i +#include +#include +#endif + +using opus::fp8_t; +using opus::fp32_t; +using fp8x8_t = opus::fp8x8_t; +using fp32x4_t = opus::vector_t; + +constexpr int M = 16, N = 16, KD = 448, TILE = 64, NTILE = 7; + +__global__ void qk_nope(const fp8_t* __restrict__ q8, // [M, KD] + const fp8_t* __restrict__ k8, // [N, KD] + const float* __restrict__ sq, // [M, NTILE] + const float* __restrict__ sk, // [N, NTILE] + float* __restrict__ S) // [M, N] +{ +#if defined(__gfx950__) + using namespace opus; + int lane = (int)__builtin_amdgcn_workitem_id_x(); + int m = lane % 16; // A-load row + int kblk = lane / 16; // 0..3, which 8-wide K sub-block this lane loads + int n = lane % 16; // B/output column + float s[4] = {0.f, 0.f, 0.f, 0.f}; + + for (int t = 0; t < NTILE; ++t) { + fp32x4_t vc{0.f, 0.f, 0.f, 0.f}; + for (int kk = 0; kk < 2; ++kk) { // 2 x 32 = 64 (one tile) + int kbase = t * TILE + kk * 32; + fp8x8_t a_reg, b_reg; +#pragma unroll + for (int j = 0; j < 8; ++j) { + a_reg[j] = q8[m * KD + kbase + kblk * 8 + j]; + b_reg[j] = k8[n * KD + kbase + kblk * 8 + j]; + } + vc = mfma{}(a_reg, b_reg, vc); + } + // vc[i] = partial Q.K over this 64-tile for C[m_out=kblk*4+i][n] +#pragma unroll + for (int i = 0; i < 4; ++i) { + int mo = kblk * 4 + i; + s[i] += vc[i] * sq[mo * NTILE + t] * sk[n * NTILE + t]; + } + } +#pragma unroll + for (int i = 0; i < 4; ++i) S[(kblk * 4 + i) * N + n] = s[i]; +#endif +} + +#ifndef __HIP_DEVICE_COMPILE__ +static float e4m3_decode(unsigned char b) { + int s = (b >> 7) & 1, e = (b >> 3) & 0xF, mant = b & 0x7; + float v; + if (e == 0) v = (float)mant / 8.0f * 0.015625f; // 2^-6 subnormal + else v = (1.0f + (float)mant / 8.0f) * ldexpf(1.0f, e - 7); + return s ? -v : v; +} +static unsigned char rand_fp8() { + // Positive-only (clear sign) so QK sums don't catastrophically cancel; this + // isolates the computation logic from fp32-vs-double rounding on near-zero + // results. Also cap exponent to keep magnitudes ~O(1) (realistic attention). + for (;;) { + unsigned char b = (unsigned char)(rand() & 0x3F); // sign=0, exp<=7 + int e = (b >> 3) & 0xF, mant = b & 0x7; + if (e == 15 && mant == 7) continue; // skip NaN (e4m3fn) + return b; + } +} + +int main() { + srand(1234); + const int qn = M * KD, kn = N * KD, sqn = M * NTILE, skn = N * NTILE; + unsigned char *hq = new unsigned char[qn], *hk = new unsigned char[kn]; + float *hsq = new float[sqn], *hsk = new float[skn]; + for (int i = 0; i < qn; ++i) hq[i] = rand_fp8(); + for (int i = 0; i < kn; ++i) hk[i] = rand_fp8(); + // scales: random powers of two 2^{-2..2} + for (int i = 0; i < sqn; ++i) hsq[i] = ldexpf(1.0f, (rand() % 5) - 2); + for (int i = 0; i < skn; ++i) hsk[i] = ldexpf(1.0f, (rand() % 5) - 2); + + // CPU reference: dequant then Q@K^T + float* ref = new float[M * N]; + for (int m = 0; m < M; ++m) + for (int n = 0; n < N; ++n) { + double acc = 0; + for (int t = 0; t < NTILE; ++t) { + double part = 0; + for (int d = t * TILE; d < (t + 1) * TILE; ++d) + part += (double)e4m3_decode(hq[m * KD + d]) * (double)e4m3_decode(hk[n * KD + d]); + acc += part * (double)hsq[m * NTILE + t] * (double)hsk[n * NTILE + t]; + } + ref[m * N + n] = (float)acc; + } + + void *dq, *dk; float *dsq, *dsk, *dS; + hipMalloc(&dq, qn); hipMalloc(&dk, kn); + hipMalloc(&dsq, sqn * 4); hipMalloc(&dsk, skn * 4); hipMalloc(&dS, M * N * 4); + hipMemcpy(dq, hq, qn, hipMemcpyHostToDevice); + hipMemcpy(dk, hk, kn, hipMemcpyHostToDevice); + hipMemcpy(dsq, hsq, sqn * 4, hipMemcpyHostToDevice); + hipMemcpy(dsk, hsk, skn * 4, hipMemcpyHostToDevice); + hipLaunchKernelGGL(qk_nope, dim3(1), 64, 0, 0, + (const fp8_t*)dq, (const fp8_t*)dk, dsq, dsk, dS); + hipDeviceSynchronize(); + float* hS = new float[M * N]; + hipMemcpy(hS, dS, M * N * 4, hipMemcpyDeviceToHost); + + double max_abs = 0, max_rel = 0; + for (int i = 0; i < M * N; ++i) { + double d = fabs((double)hS[i] - (double)ref[i]); + max_abs = fmax(max_abs, d); + max_rel = fmax(max_rel, d / (fabs((double)ref[i]) + 1e-6)); + } + printf("QK-nope fp8-MFMA + sw-scale vs dequant ref: max_abs=%.5f max_rel=%.5f\n", max_abs, max_rel); + printf("sample S[0][0..3] gpu=[%.3f %.3f %.3f %.3f] ref=[%.3f %.3f %.3f %.3f]\n", + hS[0], hS[1], hS[2], hS[3], ref[0], ref[1], ref[2], ref[3]); + printf("%s\n", (max_rel < 1e-3) ? "PASS" : "FAIL"); + return (max_rel < 1e-3) ? 0 : 1; +} +#endif diff --git a/op_tests/pa_sparse_prefill_opus_fp8_quant.py b/op_tests/pa_sparse_prefill_opus_fp8_quant.py new file mode 100644 index 0000000000..a42aac60a0 --- /dev/null +++ b/op_tests/pa_sparse_prefill_opus_fp8_quant.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""FP8 quant/dequant helpers for the OPUS sparse prefill v4 (asm-v4-style) path. + +Mirrors the DeepSeek-V4 mixed-precision head-dim layout used by the asm v4 MLA +decode kernel (aiter PR #3112, op_tests/test_mla_v4_nm.py): + + head dim D = 512 = NOPE(448, FP8 e4m3) + ROPE(64, BF16) + +The 448 NOPE elements are split into 7 tiles of 64; each tile carries one +**e8m0** (power-of-two) scale, computed as ``2^ceil(log2(amax/fp8_max))`` and +applied on dequant as ``bf16 = nope_fp8.to(f32) * scale``. The 64 ROPE elements +stay BF16 and are never quantized. + +Design note vs asm v4: asm v4 embeds the 7 scales (duplicated x2 -> 14 bytes) +inside the 512-byte FP8 token record so its prebuilt .co can read them from +fixed offsets. Here we instead emit a **separate fp32 scale tensor** of shape +``[*, 7]`` (the values are still e8m0-rounded, i.e. exact powers of two). The +math is identical; a fresh HIP kernel is much cleaner consuming a separate +scale tensor than extracting bytes interleaved in the data buffer. +""" + +from __future__ import annotations + +import torch + +# Layout constants (DeepSeek-V4 / asm-v4). +D_FULL = 512 +D_NOPE = 448 # FP8-quantized +D_ROPE = 64 # BF16, never quantized +TILE = 64 # NOPE elements sharing one e8m0 scale +NUM_TILES = D_NOPE // TILE # 7 +assert D_NOPE + D_ROPE == D_FULL +assert D_NOPE % TILE == 0 + + +def _fp8_dtype() -> torch.dtype: + # gfx950 = OCP e4m3fn. (gfx942 would be e4m3fnuz; tests run on gfx950.) + return torch.float8_e4m3fn + + +def cast_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor: + """Round a positive fp32 scale up to the nearest power of two (e8m0 grid). + + Matches test_mla_v4_nm.py::_cast_scale_inv_to_ue8m0 : + 2 ** ceil(log2(clamp_min(scale, 1e-4))) + """ + return torch.pow(2.0, torch.clamp_min(scale, 1e-4).log2().ceil()).to(torch.float32) + + +def quantize_to_v4_fp8(x_bf16: torch.Tensor): + """BF16 ``[..., 512]`` -> (nope_fp8 ``[..., 448]``, rope_bf16 ``[..., 64]``, + scale_fp32 ``[..., 7]``). + + Per 64-tile of NOPE: ``scale = ue8m0(amax / fp8_max)``, + ``nope_fp8 = (tile / scale).to(fp8)``. ROPE is the trailing 64 dims, kept BF16. + """ + assert ( + x_bf16.shape[-1] == D_FULL + ), f"expected last dim {D_FULL}, got {x_bf16.shape[-1]}" + fp8 = _fp8_dtype() + fp8_max = torch.finfo(fp8).max + + nope = x_bf16[..., :D_NOPE] + rope = x_bf16[..., D_NOPE:].to(torch.bfloat16).contiguous() + + leading = x_bf16.shape[:-1] + nope_tiled = nope.reshape(*leading, NUM_TILES, TILE).float() + amax = nope_tiled.abs().amax(dim=-1) # [..., 7] + scale = cast_scale_to_ue8m0(amax / fp8_max) # [..., 7] power-of-two + nope_q = (nope_tiled / scale.unsqueeze(-1)).to(fp8) # [..., 7, 64] + nope_q = nope_q.reshape(*leading, D_NOPE).contiguous() + return nope_q, rope, scale.contiguous() + + +def dequantize_v4_fp8( + nope_fp8: torch.Tensor, # [..., 448] fp8 + rope_bf16: torch.Tensor, # [..., 64] bf16 + scale_fp32: torch.Tensor, # [..., 7] fp32 (power-of-two) + out_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Inverse of :func:`quantize_to_v4_fp8`. Returns ``[..., 512]`` (out_dtype).""" + leading = nope_fp8.shape[:-1] + nope = nope_fp8.reshape(*leading, NUM_TILES, TILE).float() + nope = nope * scale_fp32.unsqueeze(-1) # broadcast tile scale + nope = nope.reshape(*leading, D_NOPE) + full = torch.cat([nope, rope_bf16.float()], dim=-1) # [..., 512] + return full.to(out_dtype) + + +if __name__ == "__main__": + # Roundtrip + dequant sanity (CPU; no GPU needed). + torch.manual_seed(0) + x = torch.randn(5, 3, D_FULL, dtype=torch.bfloat16) + nope_q, rope, scale = quantize_to_v4_fp8(x) + assert nope_q.shape == (5, 3, D_NOPE) and nope_q.dtype == _fp8_dtype() + assert rope.shape == (5, 3, D_ROPE) and rope.dtype == torch.bfloat16 + assert scale.shape == (5, 3, NUM_TILES) and scale.dtype == torch.float32 + # scales are exact powers of two + assert torch.all((scale.log2().round() - scale.log2()).abs() < 1e-5) + deq = dequantize_v4_fp8(nope_q, rope, scale) + # rope is lossless (bf16->bf16) + assert torch.allclose( + deq[..., D_NOPE:], x[..., D_NOPE:].to(torch.float32).to(torch.bfloat16) + ) + # nope within fp8 quant error of the per-tile range + err = (deq[..., :D_NOPE].float() - x[..., :D_NOPE].float()).abs() + rel = err / (x[..., :D_NOPE].float().abs() + 1e-3) + print(f"nope dequant: max_abs_err={err.max():.4f} mean_rel={rel.mean():.4f}") + print("roundtrip OK") diff --git a/op_tests/test_pa_sparse_prefill_opus_fp8.py b/op_tests/test_pa_sparse_prefill_opus_fp8.py new file mode 100644 index 0000000000..26cddfcd73 --- /dev/null +++ b/op_tests/test_pa_sparse_prefill_opus_fp8.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +r"""FP8 (DeepSeek-V4 / asm-v4 layout) path for OPUS sparse paged prefill attention. + +Head dim D = 512 = NOPE(448, FP8 e4m3, per-64-tile e8m0 scale) + ROPE(64, BF16), +for both Q and KV, mirroring aiter PR #3112 (asm v4 MLA decode). + +This test drives the *correctness oracle* and the data path end-to-end: + + bf16 ground-truth q/kv --quantize_to_v4_fp8--> (nope fp8, rope bf16, scale fp32) + | | + v (fp8-dequant reference) v (kernel-under-test) + dequantize_v4_fp8 -> existing fp32 SDPA reference dequant -> attention kernel + \________________ checkAllclose ________________/ + +Milestone v1 (this file): the "kernel-under-test" dequants on-GPU and runs the +*existing, proven* bf16 ``pa_sparse_prefill_opus``. This validates the v4 fp8 +format, the scales, and the reference wiring on real gfx950 hardware before the +fused FP8 ``__global__`` lands. When the fused kernel is ready, only +``_run_kernel_under_test`` changes (swap in ``pa_sparse_prefill_opus_fp8``). +""" + +from __future__ import annotations + +import argparse +import itertools +import math +import sys +from typing import Optional + +import pandas as pd +import pytest +import torch + +import aiter # noqa: F401 +from aiter.ops.pa_sparse_prefill_opus import ( + pa_sparse_prefill_opus_fp8, + pa_sparse_prefill_opus_fp8_fused, +) +from aiter.test_common import benchmark, checkAllclose + +# Reuse the bf16 test's validated reference + CSR generators + skip helpers. +from test_pa_sparse_prefill_opus import ( # noqa: E402 + _dense_csr, + _empty_csr, + _random_csr, + _ref_pa_sparse_prefill_opus, + _skip_if_unsupported, +) +from pa_sparse_prefill_opus_fp8_quant import ( # noqa: E402 + D_FULL, + dequantize_v4_fp8, + quantize_to_v4_fp8, +) + +_MODES = ("sparse", "dense", "empty") + + +def _make_fp8_inputs( + n: int, + h: int, + total_pages: int, + total_tokens: int, + *, + mode: str = "sparse", + device: torch.device | str = "cuda", + seed: int = 0, +) -> dict: + """Build bf16 ground-truth q/unified_kv/kv, quantize to the v4 fp8 layout, + and return everything both the reference (dequant) and the kernel need. + + Returns a dict with: + bf16 ground truth: q_bf16, unified_kv_bf16, kv_bf16 + fp8 packed: q_nope/q_rope/q_scale, ukv_nope/ukv_rope/ukv_scale, + kv_nope/kv_rope/kv_scale + csr + sink: kv_indices/indptr (prefix+extend), attn_sink + """ + assert mode in _MODES + torch.manual_seed(seed) + device = torch.device(device) + + q = (torch.randn(n, h, D_FULL, device=device, dtype=torch.float32) * 0.5).to( + torch.bfloat16 + ) + unified_kv = ( + torch.randn(total_pages, D_FULL, device=device, dtype=torch.float32) * 0.5 + ).to(torch.bfloat16) + kv = ( + torch.randn(total_tokens, D_FULL, device=device, dtype=torch.float32) * 0.5 + ).to(torch.bfloat16) + attn_sink = torch.randn(h, device=device, dtype=torch.float32) * 0.25 + + def _csr(total_rows: int, seed_offset: int): + if mode == "sparse": + return _random_csr( + n, total_rows, device=device, seed=seed * 2 + seed_offset + ) + if mode == "dense": + return _dense_csr(n, total_rows, device=device) + return _empty_csr(n, device=device) + + ip_p, ix_p = _csr(total_pages, 1) + ip_e, ix_e = _csr(total_tokens, 2) + + q_nope, q_rope, q_scale = quantize_to_v4_fp8(q) + ukv_nope, ukv_rope, ukv_scale = quantize_to_v4_fp8(unified_kv) + kv_nope, kv_rope, kv_scale = quantize_to_v4_fp8(kv) + + return dict( + q_bf16=q, + unified_kv_bf16=unified_kv, + kv_bf16=kv, + q_nope=q_nope, + q_rope=q_rope, + q_scale=q_scale, + ukv_nope=ukv_nope, + ukv_rope=ukv_rope, + ukv_scale=ukv_scale, + kv_nope=kv_nope, + kv_rope=kv_rope, + kv_scale=kv_scale, + kv_indices_prefix=ix_p, + kv_indptr_prefix=ip_p, + kv_indices_extend=ix_e, + kv_indptr_extend=ip_e, + attn_sink=attn_sink, + ) + + +def _ref_from_fp8(inp: dict, softmax_scale: float) -> torch.Tensor: + """Dequantize the fp8 tensors the kernel sees, then run the validated + fp32 SDPA reference. Isolates 'kernel math' from 'fp8 quant noise'.""" + q = dequantize_v4_fp8(inp["q_nope"], inp["q_rope"], inp["q_scale"]) + ukv = dequantize_v4_fp8(inp["ukv_nope"], inp["ukv_rope"], inp["ukv_scale"]) + kv = dequantize_v4_fp8(inp["kv_nope"], inp["kv_rope"], inp["kv_scale"]) + return _ref_pa_sparse_prefill_opus( + q=q, + unified_kv=ukv, + kv_indices_prefix=inp["kv_indices_prefix"], + kv_indptr_prefix=inp["kv_indptr_prefix"], + kv=kv, + kv_indices_extend=inp["kv_indices_extend"], + kv_indptr_extend=inp["kv_indptr_extend"], + attn_sink=inp["attn_sink"], + softmax_scale=softmax_scale, + ) + + +def _run_kernel_under_test(inp: dict, softmax_scale: float) -> torch.Tensor: + """The fp8 device op: on-GPU dequant (standalone kernel) -> bf16 attention.""" + return pa_sparse_prefill_opus_fp8( + q_nope=inp["q_nope"], + q_rope=inp["q_rope"], + q_scale=inp["q_scale"], + unified_kv_nope=inp["ukv_nope"], + unified_kv_rope=inp["ukv_rope"], + unified_kv_scale=inp["ukv_scale"], + kv_nope=inp["kv_nope"], + kv_rope=inp["kv_rope"], + kv_scale=inp["kv_scale"], + kv_indices_prefix=inp["kv_indices_prefix"], + kv_indptr_prefix=inp["kv_indptr_prefix"], + kv_indices_extend=inp["kv_indices_extend"], + kv_indptr_extend=inp["kv_indptr_extend"], + attn_sink=inp["attn_sink"], + softmax_scale=softmax_scale, + ) + + +@benchmark() +def run_fp8( + n: int, + h: int, + total_pages: int, + total_tokens: int, + *, + mode: str = "sparse", + seed: int = 0, + verify: bool = True, +) -> Optional[dict]: + if _skip_if_unsupported(d=D_FULL): + return None + inp = _make_fp8_inputs(n, h, total_pages, total_tokens, mode=mode, seed=seed) + softmax_scale = 1.0 / math.sqrt(D_FULL) + + row = { + "nnz_prefix": int(inp["kv_indices_prefix"].numel()), + "nnz_extend": int(inp["kv_indices_extend"].numel()), + } + if verify: + ref = _ref_from_fp8(inp, softmax_scale) + got = _run_kernel_under_test(inp, softmax_scale) + # fp8 e4m3 tolerance (matches the v4 nm accuracy convention ~3e-2). + checkAllclose( + got, + ref, + rtol=4e-2, + atol=4e-2, + msg=f"[fp8 N={n} H={h} D={D_FULL} pages={total_pages} tokens={total_tokens} mode={mode}]", + ) + return row + + +_PYTEST_SHAPES = [ + (64, 16, 256, 256), + (128, 32, 256, 256), + (64, 64, 1024, 1024), + (256, 128, 2048, 2048), +] +_PYTEST_MODES = ["sparse", "dense", "empty"] + + +@pytest.mark.parametrize("mode", _PYTEST_MODES) +@pytest.mark.parametrize( + "n,h,total_pages,total_tokens", + _PYTEST_SHAPES, + ids=lambda v: "x".join(map(str, v)) if isinstance(v, tuple) else str(v), +) +def test_pa_sparse_prefill_opus_fp8(n, h, total_pages, total_tokens, mode): + run_fp8( + n=n, + h=h, + total_pages=total_pages, + total_tokens=total_tokens, + mode=mode, + seed=(hash((n, h, total_pages, total_tokens, mode)) & 0xFFFF), + verify=True, + ) + + +def _run_fused(inp: dict, softmax_scale: float) -> torch.Tensor: + return pa_sparse_prefill_opus_fp8_fused( + q_nope=inp["q_nope"], + q_rope=inp["q_rope"], + q_scale=inp["q_scale"], + unified_kv_nope=inp["ukv_nope"], + unified_kv_rope=inp["ukv_rope"], + unified_kv_scale=inp["ukv_scale"], + kv_nope=inp["kv_nope"], + kv_rope=inp["kv_rope"], + kv_scale=inp["kv_scale"], + kv_indices_prefix=inp["kv_indices_prefix"], + kv_indptr_prefix=inp["kv_indptr_prefix"], + kv_indices_extend=inp["kv_indices_extend"], + kv_indptr_extend=inp["kv_indptr_extend"], + attn_sink=inp["attn_sink"], + softmax_scale=softmax_scale, + ) + + +@pytest.mark.parametrize("mode", _PYTEST_MODES) +@pytest.mark.parametrize( + "n,h,total_pages,total_tokens", + _PYTEST_SHAPES, + ids=lambda v: "x".join(map(str, v)) if isinstance(v, tuple) else str(v), +) +def test_pa_sparse_prefill_opus_fp8_fused(n, h, total_pages, total_tokens, mode): + if _skip_if_unsupported(d=D_FULL): + return + inp = _make_fp8_inputs( + n, + h, + total_pages, + total_tokens, + mode=mode, + seed=(hash((n, h, total_pages, total_tokens, mode)) & 0xFFFF), + ) + softmax_scale = 1.0 / math.sqrt(D_FULL) + ref = _ref_from_fp8(inp, softmax_scale) + got = _run_fused(inp, softmax_scale) + checkAllclose( + got, + ref, + rtol=4e-2, + atol=4e-2, + msg=f"[fused fp8 N={n} H={h} pages={total_pages} tokens={total_tokens} mode={mode}]", + ) + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("-n", "--n_tokens", type=int, nargs="*", default=[1024]) + p.add_argument("--h_q", type=int, nargs="*", default=[16, 32, 64, 128]) + p.add_argument("--total_pages", type=int, nargs="*", default=[4096]) + p.add_argument( + "--mode", type=str, nargs="*", default=["sparse", "dense"], choices=list(_MODES) + ) + p.add_argument("--seed", type=int, default=0) + args = p.parse_args() + rows = [] + for n, h, mode, pages in itertools.product( + args.n_tokens, args.h_q, args.mode, args.total_pages + ): + r = run_fp8( + n=n, h=h, total_pages=pages, total_tokens=n, mode=mode, seed=args.seed + ) + if r: + rows.append(r) + if rows: + print() + print(pd.DataFrame(rows).to_string(index=False)) + sys.exit(0)