Skip to content

[feat] FP8 (DeepSeek-V4 layout) sparse paged prefill attention#3583

Open
carlushuang wants to merge 7 commits into
mainfrom
carhuang/sparse_attn_f8_kv
Open

[feat] FP8 (DeepSeek-V4 layout) sparse paged prefill attention#3583
carlushuang wants to merge 7 commits into
mainfrom
carhuang/sparse_attn_f8_kv

Conversation

@carlushuang
Copy link
Copy Markdown
Collaborator

Summary

Adds FP8 KV-cache support to the OPUS sparse paged prefill attention (pa_sparse_prefill_opus) using the DeepSeek-V4 / asm-v4 mixed-precision head-dim layout:

D = 512 = 448 NOPE (fp8 e4m3, per-64-element-tile e8m0 scale) + 64 ROPE (bf16), for both Q and KV (prefix unified_kv + extend kv).

The existing bf16 kernel and all current behavior are unchanged.

What's added (two ops, both validated on gfx950)

Op Approach Perf Notes
pa_sparse_prefill_opus_fp8 dequant-prepass: a standalone fp8→bf16 device kernel writes transient bf16 scratch, then the existing fast bf16 attention runs ~1.04–1.15× bf16 The performant path. Persistent KV-cache stays fp8 (capacity preserved); only a per-call transient bf16 scratch.
pa_sparse_prefill_opus_fp8_fused fully fused single __global__: fp8 QK-nope MFMA + software per-64-tile e8m0 scale, bf16 QK-rope, bf16 PV with on-chip V dequant, 4-way KV split (prefix/extend × nope/rope), per-head sink, no bf16 scratch ~26× bf16 (H=128) Correctness-first; occupancy/latency-bound. Perf-rewrite path is characterized (see below).

Design notes

  • The mxfp8 QK uses the proven non-scaled fp8 16×16×32 MFMA + software per-64-tile scaling (the opus_gemm a8w8_scale pattern), validated against the canonical op_tests/opus/device/test_mxfp.cu. This sidesteps the hardware MX scale-operand encoding and naturally handles the 64-element quant tiles (no dup×2 needed).
  • Scales are emitted as a separate fp32 [*,7] tensor (e8m0-rounded powers of two) rather than asm-v4's in-buffer packed layout — identical math, much cleaner for a fresh HIP kernel.

Validation (gfx950, rocm/atom 7.2.3 image)

  • 48/48 tests pass: 24 bf16 (no regression) + 12 fp8 dequant-prepass + 12 fp8 fused — sparse/dense/empty × H∈{16,32,64,128}, vs an fp8-dequant reference within fp8 tolerance.
  • Standalone MFMA probes validate the fp8 QK math to ~5e-5 (op_tests/opus/qk_{nope,full}_probe.cc).

Perf state (fused kernel)

The fused kernel is correct but ~26× off the hand-tuned bf16 baseline. Root cause (measured): VGPR-occupancy + memory-latency bound — incremental knobs (LDS reduction +2.6×, multi-warp +5–20%, larger tiles / forced occupancy = regressions). The dequant-prepass is the recommended performant path today; closing the fused kernel to parity requires porting the bf16 kernel's async double-buffered pipeline onto the validated fp8 QK path (future work).

Files

  • csrc/include/pa_sparse_prefill_opus.h — dequant + fused __global__ kernels, fp8 kargs, public API
  • csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu — fp8 host launchers
  • csrc/include/rocm_ops.hpp, csrc/pybind/pa_sparse_prefill_opus_pybind.cu — pybind registration
  • aiter/ops/pa_sparse_prefill_opus.py — python ops + wrappers
  • op_tests/test_pa_sparse_prefill_opus_fp8.py, op_tests/pa_sparse_prefill_opus_fp8_quant.py — oracle test + v4 fp8 quant helpers
  • op_tests/bench_pa_sparse_prefill_opus_fp8.py — benchmark
  • op_tests/opus/qk_{nope,full}_probe.cc — standalone MFMA validation probes

Test plan

  • pytest op_tests/test_pa_sparse_prefill_opus_fp8.py (gfx950) — 24/24 fp8 pass
  • pytest op_tests/test_pa_sparse_prefill_opus.py (gfx950) — 24/24 bf16, no regression
  • CI on MI350 runner

🤖 Generated with Claude Code

Add FP8 KV-cache support to the OPUS sparse paged prefill attention
(pa_sparse_prefill_opus) using the DeepSeek-V4 / asm-v4 mixed-precision
head-dim layout: D=512 = 448 NOPE (fp8 e4m3, per-64-tile e8m0 scale) +
64 ROPE (bf16), for both Q and KV.

Two new ops, both validated on gfx950 (48/48 incl. bf16 no-regression):

- pa_sparse_prefill_opus_fp8: dequant-prepass (standalone fp8->bf16 device
  kernel, then the existing bf16 attention). ~1.04-1.15x bf16 while keeping
  fp8 KV-cache storage; this is the performant path.

- pa_sparse_prefill_opus_fp8_fused: fully fused single-kernel mxfp8 attention
  (fp8 QK-nope MFMA + software per-64-tile scale, bf16 QK-rope, bf16 PV with
  on-chip V dequant, 4-way KV split prefix/extend x nope/rope, per-head sink,
  no bf16 scratch). Correctness-first; ~26x bf16 (occupancy/latency-bound).

The bf16 kernel and all existing behavior are unchanged.

Files:
- csrc/include/pa_sparse_prefill_opus.h: dequant + fused __global__ kernels,
  fp8 kargs, public API.
- csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu: fp8 host launchers.
- csrc/include/rocm_ops.hpp, csrc/pybind/pa_sparse_prefill_opus_pybind.cu:
  pybind registration.
- aiter/ops/pa_sparse_prefill_opus.py: python ops + wrappers.
- op_tests/test_pa_sparse_prefill_opus_fp8.py,
  op_tests/pa_sparse_prefill_opus_fp8_quant.py: oracle test + v4 fp8 quant.
- op_tests/bench_pa_sparse_prefill_opus_fp8.py: benchmark.
- op_tests/opus/qk_{nope,full}_probe.cc: standalone MFMA validation probes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@carlushuang carlushuang requested a review from a team June 7, 2026 04:55
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 7, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3583 --add-label <label>

carlushuang and others added 6 commits June 7, 2026 05:00
Black formatting + ruff fixes (remove unused imports, split semicolon
statement) to pass the Black/Ruff CI checks. No functional change; fp8
tests still 24/24 on gfx950.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ion)

Make pa_prefill_fp8_fused_kernel a `template<class Traits>` __global__ with a
`pa_prefill_fp8_traits<Q_TILE, KV_TILE, D_NOPE, D_ROPE, KV_SCALE_TILE,
MAX_WARPS>` struct holding the compile-time config, matching the existing
bf16 pa_prefill_16mx{8,1}_* kernels (previously the fp8 kernel hardcoded
constants in a namespace). Launcher instantiates pa_prefill_fp8_traits<16,32,
448,64,64,4>. No behavior change; 48/48 tests pass on gfx950.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… aliases)

Per convention (matches the bf16 pa_prefill_* kernels), reference the Traits
members directly as T::QTILE / T::KVTILE / T::DNOPE / ... instead of aliasing
them to a block of local `constexpr int`. No behavior change; 12/12 fused
tests pass on gfx950.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
fp8x8_t / bf16x8_t (REGISTER_DTYPE-generated) and fp32x4_t already come from
`using namespace opus;` in the kernel, so the pa_fp8_fused alias namespace was
redundant. Use the opus types inline (matches the bf16 kernels). 12/12 fused
tests pass on gfx950.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…+ bf16 MFMA)

D_KV (fp8 e4m3 bytes) for LDS storage geometry, D_ATTN (bf16) for MFMA/Q/O.
smem geometry derives from sizeof(D_KV)=1 (D_128B=128, etc.); VEC_KV=8 reads
8 fp8 -> dequant 8 bf16 at the smem->reg read. Foundation for the D_KV refactor
(fuse fp8 KV into the 16mx8 pipeline). Unused yet; header compiles, 12/12 fp8.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The naive fp8-in-LDS path (reuse bf16 layouts with sizeof swapped) fails. Two
findings from op_tests/opus/fp8_lds_probe.cc on gfx950 (mi355-gpu-15):

1. fp8 traits smem geometry must mirror the bf16 ELEMENT layout (derive from
   sizeof(D_ATTN), not sizeof(D_KV)): mma0 is bf16, fp8 is dequanted AFTER
   placement, so the read must deliver the bf16 MFMA b-operand order. Fixed in
   pa_prefill_16mx8_fp8_traits. Necessary but NOT sufficient.

2. Root cause: opus _async_load -> raw_ptr_buffer_load_lds on gfx950 only has
   per-lane size branches {1,2,4,12,16} bytes (opus.hpp ~1690). fp8 VEC_KV=8 =>
   8B/lane => no matching branch => silent no-op => LDS garbage => probe FAIL.

=> fp8 KV must use VEC_KV=16 (16B/lane DMA), i.e. a distinct gkv/skv/rk tiling
   (16 fp8 d-elems/lane vs bf16's 8) whose u_rk read still yields the bf16-mma
   b-operand. That layout redesign is the remaining crux; K dequant, Q dequant,
   scalar-scale folding and smem geometry are already solved.

Probe + traits geometry fix only (standalone; traits unused by shipped kernels,
branch stays green).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant