Skip to content

mha_native: native HIP D64 BF16 split-K forward backend for flash_attn_func#3581

Open
rocking5566 wants to merge 19 commits into
mainfrom
native-fmha-splitkv
Open

mha_native: native HIP D64 BF16 split-K forward backend for flash_attn_func#3581
rocking5566 wants to merge 19 commits into
mainfrom
native-fmha-splitkv

Conversation

@rocking5566
Copy link
Copy Markdown
Contributor

@rocking5566 rocking5566 commented Jun 6, 2026

Summary

Adds a third forward backend to aiter.flash_attn_func: a hand-written native HIP D64 BF16 split-K kernel (2 producers + 1 combine), integrated via pybind as the JIT module module_mha_fwd_native_splitkv. Forward-only; backward is unchanged (existing ASM/CK D64 bwd).

The path is purely additive — it only engages on gfx942 + dense bf16 + D=64 when a capability gate passes, and otherwise falls through to the existing fmha_v3_fwd/mha_fwd dispatch untouched.

What's included

  • Kernel: native D64 BF16 split-K forward — producers write normalized fp32 partials (O_g, LSE_g) into split-major scratch; combine folds the G partials into the final BF16 O (+ optional LSE).
  • Capability gate (can_impl_fmha_native): gfx942, bf16 q/k/v, D=64 (q & v), no bias/alibi/swa/dropout/sink/fp8/varlen, nhead_q % nhead_k == 0. Causal additionally requires seqlen_k >= seqlen_q (avoids fully-masked-row NaN-vs-zero divergence).
  • Split-count heuristic: two-regime occupancy model tuned on 100 measured D64 shapes. Returns G ∈ {0,2,4,8,16}; G==0 falls back to the CK non-split path. CU count detected dynamically via get_cu_num().
  • Public knob: new optional num_splits arg on flash_attn_func (0 = heuristic, >=2 forces native if capable), threaded through the autograd chain. -ns/--num_splits added to op_tests/test_mha.py.

Performance (gfx942, MI300)

CK baseline (G=0) vs heuristic-selected native split-KV:

# Config (b, h, hk, s, sk, d) G=0 TFLOPS Heuristic G Heuristic TFLOPS Speedup
1 (1, 2, 2, 40000, 40000, 64) 307.95 G=4 412.38 +33.9%
2 (1, 16, 16, 128, 65535, 64) 16.38 G=16 173.70 +960%
3 (1, 8, 8, 32768, 32768, 64) 406.31 G=4 442.58 +8.9%
4 (8, 16, 16, 8192, 8192, 64) 454.53 G=0 454.53 +0%

Group 4 (already saturated: b=8) shows the heuristic correctly declines to split (G=0), incurring no regression where split-K can't help. Largest wins are the undersubscribed long-KV shapes (groups 1–2) where split-K fills the machine.

Correctness

Verified via op_tests/test_mha.py against attention_ref within bf16 tolerance:

  • Square 4096², non-square 256×8192 (both causal directions, full fwd+bwd), decode q=1 — native path confirmed in use (forced -ns).
  • S=40000 heuristic auto-trigger (no -ns) under kernel-trace; forward ref correctness.

Device-ISA parity with the reference build verified (required -mllvm -enable-post-misched=1 on the module's hip flags); .amdhsa descriptors identical, packed v_pk_mul_f32 restored.

Test plan

  • CI builds module_mha_fwd_native_splitkv (JIT) on gfx942
  • op_tests/test_mha.py targeted D64 shapes pass (forced -ns and heuristic)
  • Confirm fall-through paths (non-gfx942 / non-D64 / biased / varlen) unaffected

🤖 Generated with Claude Code

rocking5566 and others added 14 commits June 5, 2026 21:07
Comments in the vendored device headers referred to "the four existing
entries", "four call sites" and "[_varlen]" entry files that exist in the
upstream source but not on this branch (only the msk{0,1}_split producers
and combine ship here; fmha_fwd_d64_device is always instantiated with
IsSplit=true, IsVarlen=false). Reword to describe this branch's actual
entry set. Comments only; no code lines changed (ISA parity preserved).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
… model

Port the two-regime occupancy heuristic tuned on 100 measured D64 shapes,
replacing the single hardcoded special case. G==0 falls back to the CK
non-split-KV kernel. Thread seqlen_q through and detect CU count dynamically
via get_cu_num() instead of assuming 304.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The heuristic now also returns 0 (CK fallback), so the dispatch falls
through for ns <= 1, not just ns == 1. Comment only.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 6, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3581 --add-label <label>

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new native HIP D64 BF16 split-K forward backend for aiter.flash_attn_func (gfx942-only) with a new num_splits knob and a heuristic selector, while leaving backward dispatch unchanged.

Changes:

  • Introduces a native split-K forward implementation (producer + combine) and exposes it via pybind/JIT as module_mha_fwd_native_splitkv.
  • Extends flash_attn_func / autograd plumbing with a new num_splits argument and a CU-count-based heuristic for choosing split count.
  • Updates MHA op tests to pass/parse the new --num_splits option.

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
op_tests/test_mha.py Adds --num_splits CLI option and threads it into test/benchmark calls.
csrc/pybind/mha_fwd_native_splitkv_pybind.cu Defines the pybind module entrypoint for the native split-K forward op.
csrc/kernels/mha_native/runner/params.hpp Adds kernarg structs and kernel geometry constants for the native split-K implementation.
csrc/kernels/mha_native/mha_native.h Declares the C++ API for mha_fwd_native_splitkv.
csrc/kernels/mha_native/mha_native_launch.h Declares host launch wrappers for split producers and combine kernels.
csrc/kernels/mha_native/mha_fwd_native_splitkv.cu Implements the Torch-facing C++ launcher (alloc scratch, launch producers + combine).
csrc/kernels/mha_native/fused/pipeline.hpp Implements the fused forward pipeline with optional split-K narrowing + scratch epilogue.
csrc/kernels/mha_native/fused/op_softmax.hpp Implements online softmax masking/max/sum/exp2 for the pipeline.
csrc/kernels/mha_native/fused/op_lds.hpp Implements DRAM↔LDS staging (async K copy, V load+permute+store) and sync helpers.
csrc/kernels/mha_native/fused/op_gemm.hpp Implements GEMM0 (Q·Kᵀ) and GEMM1 (P·V) using MFMA with required layout/swizzle.
csrc/kernels/mha_native/fused/op_epilog.hpp Implements bf16 epilogue and fp32 split scratch epilogue.
csrc/kernels/mha_native/fused/op_combine.hpp Implements the split-K combine pass (convex-combination reweighting + bf16 store).
csrc/kernels/mha_native/fmha_fwd_d64_bf16_msk1_split.cu Adds causal split producer kernel entry + host launcher.
csrc/kernels/mha_native/fmha_fwd_d64_bf16_msk0_split.cu Adds non-causal split producer kernel entry + host launcher.
csrc/kernels/mha_native/fmha_fwd_d64_bf16_combine.cu Adds combine kernel entry + host launcher.
csrc/include/rocm_ops.hpp Exposes mha_fwd_native_splitkv via a new pybind macro.
aiter/ops/mha.py Adds JIT binding for native split-K, capability gating, heuristic, and public num_splits plumbing.
aiter/jit/optCompilerConfig.json Registers the new JIT module build configuration and HIP flags/include paths.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread csrc/kernels/mha_native/mha_fwd_native_splitkv.cu
Comment thread csrc/kernels/mha_native/mha_fwd_native_splitkv.cu
Comment thread aiter/ops/mha.py
Comment thread aiter/ops/mha.py
Comment thread op_tests/test_mha.py
rocking5566 and others added 4 commits June 8, 2026 14:39
mha_fwd_native_splitkv is a public aiter:: symbol exposed via pybind, so
C++/direct callers bypass the Python-side can_impl_fmha_native gating.
Validate k/v dtype (was q-only), 4-D rank, q/k/v last-dim contiguity, and
Hq % Hk == 0 (Hk > Hq would divide by zero in device GQA grouping).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…point

out_opt is written via reinterpret_cast to bf16 using its own strides over
a grid covering B*Hq*Sq*D, so a wrong dtype/device/shape silently corrupts
memory or writes out of bounds. Validate bf16 dtype, same device as q, and
(B,Sq,Hq,D) shape in addition to the existing last-dim contiguity check.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Reject num_splits < 0 explicitly; previously negatives silently fell
  through to the heuristic path like 0, masking misconfiguration.
- Document num_splits in the flash_attn_func docstring (0=auto/heuristic,
  1=disable split-K, >=2 forces native split-K when applicable).
- Fix the test CLI --num_splits help text, which claimed ">=1 forces
  native" while the dispatch only routes to native when num_splits >= 2.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants