[feat] FP8 (DeepSeek-V4 layout) sparse paged prefill attention #3583
Open
carlushuang wants to merge 7 commits into
Add FP8 KV-cache support to the OPUS sparse paged prefill attention
(pa_sparse_prefill_opus) using the DeepSeek-V4 / asm-v4 mixed-precision
head-dim layout: D=512 = 448 NOPE (fp8 e4m3, per-64-tile e8m0 scale) +
64 ROPE (bf16), for both Q and KV.
Two new ops, both validated on gfx950 (48/48 incl. bf16 no-regression):
- pa_sparse_prefill_opus_fp8: dequant-prepass (standalone fp8->bf16 device
kernel, then the existing bf16 attention). ~1.04-1.15x bf16 while keeping
fp8 KV-cache storage; this is the performant path.
- pa_sparse_prefill_opus_fp8_fused: fully fused single-kernel mxfp8 attention
(fp8 QK-nope MFMA + software per-64-tile scale, bf16 QK-rope, bf16 PV with
on-chip V dequant, 4-way KV split prefix/extend x nope/rope, per-head sink,
no bf16 scratch). Correctness-first; ~26x slower than bf16 (occupancy/latency-bound).
The bf16 kernel and all existing behavior are unchanged.
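
The "fp8 QK-nope MFMA + software per-64-tile scale" idea can be illustrated with a scalar sketch. This is not the kernel: it is plain Python showing that scaling each 64-element partial dot product once gives the same result as dequantizing every element first. The function and argument names are hypothetical, and the sketch assumes Q and K each carry one scale per 64-element NOPE tile, as the layout description suggests.

```python
# Scalar illustration only; the real kernel uses MFMA instructions on the GPU.
D_NOPE, D_ROPE, TILE = 448, 64, 64   # sizes taken from the PR text
NUM_TILES = D_NOPE // TILE           # 7 scale tiles


def qk_dequant_first(q_nope, q_scale, k_nope, k_scale, q_rope, k_rope):
    """Reference: dequantize every NOPE element, then take one long dot product."""
    acc = 0.0
    for i in range(D_NOPE):
        t = i // TILE
        acc += (q_nope[i] * q_scale[t]) * (k_nope[i] * k_scale[t])
    # ROPE part is unquantized (bf16 in the PR), so it is a plain dot product.
    acc += sum(a * b for a, b in zip(q_rope, k_rope))
    return acc


def qk_software_scaled(q_nope, q_scale, k_nope, k_scale, q_rope, k_rope):
    """Per-tile dot product on raw quantized values, scaled once per tile."""
    acc = 0.0
    for t in range(NUM_TILES):
        lo = t * TILE
        partial = sum(q_nope[i] * k_nope[i] for i in range(lo, lo + TILE))
        acc += partial * (q_scale[t] * k_scale[t])
    acc += sum(a * b for a, b in zip(q_rope, k_rope))
    return acc
```

Because the scales are powers of two, folding them in after the per-tile accumulation changes only where the multiplication happens, not the mathematical result.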
Files:
- csrc/include/pa_sparse_prefill_opus.h: dequant + fused __global__ kernels,
fp8 kargs, public API.
- csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu: fp8 host launchers.
- csrc/include/rocm_ops.hpp, csrc/pybind/pa_sparse_prefill_opus_pybind.cu:
pybind registration.
- aiter/ops/pa_sparse_prefill_opus.py: python ops + wrappers.
- op_tests/test_pa_sparse_prefill_opus_fp8.py,
op_tests/pa_sparse_prefill_opus_fp8_quant.py: oracle test + v4 fp8 quant.
- op_tests/bench_pa_sparse_prefill_opus_fp8.py: benchmark.
- op_tests/opus/qk_{nope,full}_probe.cc: standalone MFMA validation probes.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Black formatting + ruff fixes (remove unused imports, split semicolon statement)
to pass the Black/Ruff CI checks. No functional change; fp8 tests still 24/24 on
gfx950.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ion)
Make pa_prefill_fp8_fused_kernel a `template<class Traits>` __global__ with a
`pa_prefill_fp8_traits<Q_TILE, KV_TILE, D_NOPE, D_ROPE, KV_SCALE_TILE,
MAX_WARPS>` struct holding the compile-time config, matching the existing
bf16 pa_prefill_16mx{8,1}_* kernels (previously the fp8 kernel hardcoded
constants in a namespace). Launcher instantiates pa_prefill_fp8_traits<16,32,
448,64,64,4>. No behavior change; 48/48 tests pass on gfx950.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… aliases)
Per convention (matches the bf16 pa_prefill_* kernels), reference the Traits
members directly as T::QTILE / T::KVTILE / T::DNOPE / ... instead of aliasing
them to a block of local `constexpr int`. No behavior change; 12/12 fused tests
pass on gfx950.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
fp8x8_t / bf16x8_t (REGISTER_DTYPE-generated) and fp32x4_t already come from
`using namespace opus;` in the kernel, so the pa_fp8_fused alias namespace was
redundant. Use the opus types inline (matches the bf16 kernels). 12/12 fused
tests pass on gfx950.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…+ bf16 MFMA)
D_KV (fp8 e4m3 bytes) for LDS storage geometry, D_ATTN (bf16) for MFMA/Q/O.
smem geometry derives from sizeof(D_KV)=1 (D_128B=128, etc.); VEC_KV=8 reads
8 fp8 -> dequant 8 bf16 at the smem->reg read. Foundation for the D_KV refactor
(fuse fp8 KV into the 16mx8 pipeline). Not yet used; header compiles, 12/12 fp8
tests pass.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The naive fp8-in-LDS path (reuse bf16 layouts with sizeof swapped) fails. Two
findings from op_tests/opus/fp8_lds_probe.cc on gfx950 (mi355-gpu-15):
1. fp8 traits smem geometry must mirror the bf16 ELEMENT layout (derive from
sizeof(D_ATTN), not sizeof(D_KV)): mma0 is bf16, fp8 is dequanted AFTER
placement, so the read must deliver the bf16 MFMA b-operand order. Fixed in
pa_prefill_16mx8_fp8_traits. Necessary but NOT sufficient.
2. Root cause: opus _async_load -> raw_ptr_buffer_load_lds on gfx950 only has
per-lane size branches {1,2,4,12,16} bytes (opus.hpp ~1690). fp8 VEC_KV=8 =>
8B/lane => no matching branch => silent no-op => LDS garbage => probe FAIL.
=> fp8 KV must use VEC_KV=16 (16B/lane DMA), i.e. a distinct gkv/skv/rk tiling
(16 fp8 d-elems/lane vs bf16's 8) whose u_rk read still yields the bf16-mma
b-operand. That layout redesign is the remaining crux; K dequant, Q dequant,
scalar-scale folding and smem geometry are already solved.
Probe + traits geometry fix only (standalone; traits unused by shipped kernels,
branch stays green).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
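
The per-lane size arithmetic behind finding 2 can be sketched as a host-side check. This is a hypothetical helper, not part of opus: it uses only the size set {1,2,4,12,16} quoted in the commit message, and it raises an error where the commit reports a silent no-op.

```python
# Per-lane transfer sizes (bytes) the commit message reports for gfx950.
SUPPORTED_LANE_BYTES = (1, 2, 4, 12, 16)


def lane_bytes(vec_elems, elem_bytes):
    """Bytes each lane moves for a vector of vec_elems elements."""
    return vec_elems * elem_bytes


def check_async_load(vec_elems, elem_bytes):
    """Return the per-lane byte count, or raise if no size branch matches.

    Hypothetical guard: the commit describes the unmatched case as a silent
    no-op, so failing loudly here is a design choice of this sketch.
    """
    n = lane_bytes(vec_elems, elem_bytes)
    if n not in SUPPORTED_LANE_BYTES:
        raise ValueError(
            f"{n} B/lane has no matching branch; supported: {SUPPORTED_LANE_BYTES}"
        )
    return n
```

With these numbers, bf16 at VEC=8 gives 16 B/lane and fp8 at VEC_KV=16 gives 16 B/lane, both supported, while fp8 at VEC_KV=8 gives 8 B/lane and is rejected.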
Summary
Adds FP8 KV-cache support to the OPUS sparse paged prefill attention (pa_sparse_prefill_opus) using the DeepSeek-V4 / asm-v4 mixed-precision head-dim layout: D=512 = 448 NOPE (fp8 e4m3, per-64-tile e8m0 scale) + 64 ROPE (bf16), for both Q and KV.
The existing bf16 kernel and all current behavior are unchanged.
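
As a rough illustration of the layout arithmetic, the sketch below counts the scale tiles and bytes for one D=512 row and picks a power-of-two scale for a tile. Two assumptions are not stated in the PR: that the fp8 format is OCP e4m3 with a finite maximum of 448.0, and that the scale is rounded up so the scaled values always fit. The helper names are hypothetical, and one byte per e8m0 scale is assumed.

```python
import math

# Layout constants taken from the PR text.
D_NOPE = 448          # fp8 e4m3 elements per row
D_ROPE = 64           # bf16 elements per row
SCALE_TILE = 64       # one e8m0 scale per 64 NOPE elements
NUM_SCALES = D_NOPE // SCALE_TILE

# Assumption: OCP e4m3 finite maximum; the PR does not state the variant.
FP8_E4M3_MAX = 448.0


def bytes_per_row():
    """Bytes for one row: 1 B per fp8, 2 B per bf16, 1 B per e8m0 scale."""
    return D_NOPE * 1 + D_ROPE * 2 + NUM_SCALES * 1


def e8m0_scale(amax):
    """Smallest power-of-two scale s with amax / s <= FP8_E4M3_MAX.

    Assumption: round-up policy; the PR only says the scales are
    e8m0-rounded powers of two.
    """
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))
```

Under these assumptions a row takes 583 bytes including scales, compared with 1024 bytes for 512 bf16 elements.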
What's added (two ops, both validated on gfx950)
- pa_sparse_prefill_opus_fp8: dequant-prepass (standalone fp8->bf16 device kernel, then the existing bf16 attention).
- pa_sparse_prefill_opus_fp8_fused: fully fused single `__global__` kernel: fp8 QK-nope MFMA + software per-64-tile e8m0 scale, bf16 QK-rope, bf16 PV with on-chip V dequant, 4-way KV split (prefix/extend × nope/rope), per-head sink, no bf16 scratch.
Design notes
- 16×16×32 MFMA + software per-64-tile scaling (the opus_gemm a8w8_scale pattern), validated against the canonical op_tests/opus/device/test_mxfp.cu. This sidesteps the hardware MX scale-operand encoding and naturally handles the 64-element quant tiles (no dup×2 needed).
- [*,7] tensor (e8m0-rounded powers of two) rather than asm-v4's in-buffer packed layout: identical math, much cleaner for a fresh HIP kernel.
Validation (gfx950, rocm/atom 7.2.3 image)
op_tests/opus/qk_{nope,full}_probe.cc).
Perf state (fused kernel)
The fused kernel is correct but ~26× slower than the hand-tuned bf16 baseline. Measured root cause: it is bound by VGPR occupancy and memory latency. Incremental knobs do not close the gap: LDS reduction gave +2.6×, multi-warp +5–20%, and larger tiles or forced occupancy regressed. The dequant-prepass is the recommended performant path today; bringing the fused kernel to parity requires porting the bf16 kernel's async double-buffered pipeline onto the validated fp8 QK path (future work).
Files
- csrc/include/pa_sparse_prefill_opus.h: dequant + fused `__global__` kernels, fp8 kargs, public API
- csrc/py_itfs_cu/pa_sparse_prefill_opus_kernels.cu: fp8 host launchers
- csrc/include/rocm_ops.hpp, csrc/pybind/pa_sparse_prefill_opus_pybind.cu: pybind registration
- aiter/ops/pa_sparse_prefill_opus.py: python ops + wrappers
- op_tests/test_pa_sparse_prefill_opus_fp8.py, op_tests/pa_sparse_prefill_opus_fp8_quant.py: oracle test + v4 fp8 quant helpers
- op_tests/bench_pa_sparse_prefill_opus_fp8.py: benchmark
- op_tests/opus/qk_{nope,full}_probe.cc: standalone MFMA validation probes
Test plan
- pytest op_tests/test_pa_sparse_prefill_opus_fp8.py (gfx950): 24/24 fp8 pass
- pytest op_tests/test_pa_sparse_prefill_opus.py (gfx950): 24/24 bf16, no regression
🤖 Generated with Claude Code