mha_native: native HIP D64 BF16 split-K forward backend for flash_attn_func#3581
Open
rocking5566 wants to merge 19 commits into
Open
mha_native: native HIP D64 BF16 split-K forward backend for flash_attn_func#3581rocking5566 wants to merge 19 commits into
rocking5566 wants to merge 19 commits into
Conversation
Comments in the vendored device headers referred to "the four existing
entries", "four call sites" and "[_varlen]" entry files that exist in the
upstream source but not on this branch (only the msk{0,1}_split producers
and combine ship here; fmha_fwd_d64_device is always instantiated with
IsSplit=true, IsVarlen=false). Reword to describe this branch's actual
entry set. Comments only; no code lines changed (ISA parity preserved).
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
… model Port the two-regime occupancy heuristic tuned on 100 measured D64 shapes, replacing the single hardcoded special case. G==0 falls back to the CK non-split-KV kernel. Thread seqlen_q through and detect CU count dynamically via get_cu_num() instead of assuming 304. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The heuristic now also returns 0 (CK fallback), so the dispatch falls through for ns <= 1, not just ns == 1. Comment only. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Contributor
There was a problem hiding this comment.
Pull request overview
Adds a new native HIP D64 BF16 split-K forward backend for aiter.flash_attn_func (gfx942-only) with a new num_splits knob and a heuristic selector, while leaving backward dispatch unchanged.
Changes:
- Introduces a native split-K forward implementation (producer + combine) and exposes it via pybind/JIT as
module_mha_fwd_native_splitkv. - Extends
flash_attn_func/ autograd plumbing with a newnum_splitsargument and a CU-count-based heuristic for choosing split count. - Updates MHA op tests to pass/parse the new
--num_splitsoption.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_mha.py | Adds --num_splits CLI option and threads it into test/benchmark calls. |
| csrc/pybind/mha_fwd_native_splitkv_pybind.cu | Defines the pybind module entrypoint for the native split-K forward op. |
| csrc/kernels/mha_native/runner/params.hpp | Adds kernarg structs and kernel geometry constants for the native split-K implementation. |
| csrc/kernels/mha_native/mha_native.h | Declares the C++ API for mha_fwd_native_splitkv. |
| csrc/kernels/mha_native/mha_native_launch.h | Declares host launch wrappers for split producers and combine kernels. |
| csrc/kernels/mha_native/mha_fwd_native_splitkv.cu | Implements the Torch-facing C++ launcher (alloc scratch, launch producers + combine). |
| csrc/kernels/mha_native/fused/pipeline.hpp | Implements the fused forward pipeline with optional split-K narrowing + scratch epilogue. |
| csrc/kernels/mha_native/fused/op_softmax.hpp | Implements online softmax masking/max/sum/exp2 for the pipeline. |
| csrc/kernels/mha_native/fused/op_lds.hpp | Implements DRAM↔LDS staging (async K copy, V load+permute+store) and sync helpers. |
| csrc/kernels/mha_native/fused/op_gemm.hpp | Implements GEMM0 (Q·Kᵀ) and GEMM1 (P·V) using MFMA with required layout/swizzle. |
| csrc/kernels/mha_native/fused/op_epilog.hpp | Implements bf16 epilogue and fp32 split scratch epilogue. |
| csrc/kernels/mha_native/fused/op_combine.hpp | Implements the split-K combine pass (convex-combination reweighting + bf16 store). |
| csrc/kernels/mha_native/fmha_fwd_d64_bf16_msk1_split.cu | Adds causal split producer kernel entry + host launcher. |
| csrc/kernels/mha_native/fmha_fwd_d64_bf16_msk0_split.cu | Adds non-causal split producer kernel entry + host launcher. |
| csrc/kernels/mha_native/fmha_fwd_d64_bf16_combine.cu | Adds combine kernel entry + host launcher. |
| csrc/include/rocm_ops.hpp | Exposes mha_fwd_native_splitkv via a new pybind macro. |
| aiter/ops/mha.py | Adds JIT binding for native split-K, capability gating, heuristic, and public num_splits plumbing. |
| aiter/jit/optCompilerConfig.json | Registers the new JIT module build configuration and HIP flags/include paths. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
mha_fwd_native_splitkv is a public aiter:: symbol exposed via pybind, so C++/direct callers bypass the Python-side can_impl_fmha_native gating. Validate k/v dtype (was q-only), 4-D rank, q/k/v last-dim contiguity, and Hq % Hk == 0 (Hk > Hq would divide by zero in device GQA grouping). Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…point out_opt is written via reinterpret_cast to bf16 using its own strides over a grid covering B*Hq*Sq*D, so a wrong dtype/device/shape silently corrupts memory or writes out of bounds. Validate bf16 dtype, same device as q, and (B,Sq,Hq,D) shape in addition to the existing last-dim contiguity check. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Reject num_splits < 0 explicitly; previously negatives silently fell through to the heuristic path like 0, masking misconfiguration. - Document num_splits in the flash_attn_func docstring (0=auto/heuristic, 1=disable split-K, >=2 forces native split-K when applicable). - Fix the test CLI --num_splits help text, which claimed ">=1 forces native" while the dispatch only routes to native when num_splits >= 2. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a third forward backend to
aiter.flash_attn_func: a hand-written native HIP D64 BF16 split-K kernel (2 producers + 1 combine), integrated via pybind as the JIT modulemodule_mha_fwd_native_splitkv. Forward-only; backward is unchanged (existing ASM/CK D64 bwd).The path is purely additive — it only engages on gfx942 + dense bf16 + D=64 when a capability gate passes, and otherwise falls through to the existing
fmha_v3_fwd/mha_fwddispatch untouched.What's included
can_impl_fmha_native): gfx942, bf16 q/k/v, D=64 (q & v), no bias/alibi/swa/dropout/sink/fp8/varlen,nhead_q % nhead_k == 0. Causal additionally requiresseqlen_k >= seqlen_q(avoids fully-masked-row NaN-vs-zero divergence).G ∈ {0,2,4,8,16};G==0falls back to the CK non-split path. CU count detected dynamically viaget_cu_num().num_splitsarg onflash_attn_func(0= heuristic,>=2forces native if capable), threaded through the autograd chain.-ns/--num_splitsadded toop_tests/test_mha.py.Performance (gfx942, MI300)
CK baseline (
G=0) vs heuristic-selected native split-KV:Group 4 (already saturated: b=8) shows the heuristic correctly declines to split (G=0), incurring no regression where split-K can't help. Largest wins are the undersubscribed long-KV shapes (groups 1–2) where split-K fills the machine.
Correctness
Verified via
op_tests/test_mha.pyagainstattention_refwithin bf16 tolerance:-ns).-ns) under kernel-trace; forward ref correctness.Device-ISA parity with the reference build verified (required
-mllvm -enable-post-misched=1on the module's hip flags);.amdhsadescriptors identical, packedv_pk_mul_f32restored.Test plan
module_mha_fwd_native_splitkv(JIT) on gfx942op_tests/test_mha.pytargeted D64 shapes pass (forced-nsand heuristic)🤖 Generated with Claude Code