Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions aiter/ops/pa_sparse_prefill_opus.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,202 @@ def pa_sparse_prefill_opus(
return out


# ---------------------------------------------------------------------------
# FP8 (DeepSeek-V4 / asm-v4 layout) variant.
# head dim 512 = NOPE(448, fp8 e4m3, per-64-tile fp32 scale) + ROPE(64, bf16),
# for both Q and KV. First implementation dequantizes to bf16 scratch (via a
# standalone device kernel) then runs the bf16 attention kernel.
# ---------------------------------------------------------------------------

_FP8_NOPE = 448
_FP8_ROPE = 64
_FP8_FULL = 512
_FP8_NUM_TILES = 7


@compile_ops("module_pa_sparse_prefill_opus", develop=True)
def pa_sparse_prefill_opus_fp8_fwd(
q_nope: torch.Tensor,
q_rope: torch.Tensor,
q_scale: torch.Tensor,
unified_kv_nope: torch.Tensor,
unified_kv_rope: torch.Tensor,
unified_kv_scale: torch.Tensor,
kv_nope: torch.Tensor,
kv_rope: torch.Tensor,
kv_scale: torch.Tensor,
kv_indices_prefix: torch.Tensor,
kv_indptr_prefix: torch.Tensor,
kv_indices_extend: torch.Tensor,
kv_indptr_extend: torch.Tensor,
attn_sink: torch.Tensor,
q_bf16: torch.Tensor,
unified_kv_bf16: torch.Tensor,
kv_bf16: torch.Tensor,
out: torch.Tensor,
softmax_scale: float,
) -> None: ...


def pa_sparse_prefill_opus_fp8(
q_nope: torch.Tensor,
q_rope: torch.Tensor,
q_scale: torch.Tensor,
unified_kv_nope: torch.Tensor,
unified_kv_rope: torch.Tensor,
unified_kv_scale: torch.Tensor,
kv_nope: torch.Tensor,
kv_rope: torch.Tensor,
kv_scale: torch.Tensor,
kv_indices_prefix: torch.Tensor,
kv_indptr_prefix: torch.Tensor,
kv_indices_extend: torch.Tensor,
kv_indptr_extend: torch.Tensor,
attn_sink: torch.Tensor,
softmax_scale: float,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""FP8 (v4-layout) sparse prefill attention.

Q and KV are stored as NOPE(448, fp8 e4m3, per-64-tile fp32 scale) +
ROPE(64, bf16). Output is bf16 ``[N, H, 512]``. The bf16 dequant scratch
is allocated here and freed when it goes out of scope.

Args mirror :func:`pa_sparse_prefill_opus` but with each q/unified_kv/kv
split into ``*_nope`` (fp8 ``[..., 448]``), ``*_rope`` (bf16 ``[..., 64]``)
and ``*_scale`` (fp32 ``[..., 7]``).
"""
gfx = get_gfx_runtime()
if gfx != "gfx950":
raise RuntimeError(f"pa_sparse_prefill_opus_fp8 requires gfx950, got {gfx}")

if q_nope.dtype != torch.float8_e4m3fn:
raise RuntimeError(f"q_nope must be fp8 e4m3fn, got {q_nope.dtype}")
if q_nope.size(-1) != _FP8_NOPE or q_rope.size(-1) != _FP8_ROPE:
raise RuntimeError(
f"expected nope last dim {_FP8_NOPE} and rope last dim {_FP8_ROPE}, "
f"got {q_nope.size(-1)} / {q_rope.size(-1)}"
)

n, h = q_nope.size(0), q_nope.size(1)
total_pages = unified_kv_nope.size(0)
total_tokens = kv_nope.size(0)
dev = q_nope.device

q_bf16 = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=dev)
unified_kv_bf16 = torch.empty(
(total_pages, _FP8_FULL), dtype=torch.bfloat16, device=dev
)
kv_bf16 = torch.empty((total_tokens, _FP8_FULL), dtype=torch.bfloat16, device=dev)
if out is None:
out = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=dev)

pa_sparse_prefill_opus_fp8_fwd(
q_nope,
q_rope,
q_scale,
unified_kv_nope,
unified_kv_rope,
unified_kv_scale,
kv_nope,
kv_rope,
kv_scale,
kv_indices_prefix,
kv_indptr_prefix,
kv_indices_extend,
kv_indptr_extend,
attn_sink,
q_bf16,
unified_kv_bf16,
kv_bf16,
out,
float(softmax_scale),
)
return out


@compile_ops("module_pa_sparse_prefill_opus", develop=True)
def pa_sparse_prefill_opus_fp8_fused_fwd(
q_nope: torch.Tensor,
q_rope: torch.Tensor,
q_scale: torch.Tensor,
unified_kv_nope: torch.Tensor,
unified_kv_rope: torch.Tensor,
unified_kv_scale: torch.Tensor,
kv_nope: torch.Tensor,
kv_rope: torch.Tensor,
kv_scale: torch.Tensor,
kv_indices_prefix: torch.Tensor,
kv_indptr_prefix: torch.Tensor,
kv_indices_extend: torch.Tensor,
kv_indptr_extend: torch.Tensor,
attn_sink: torch.Tensor,
out: torch.Tensor,
softmax_scale: float,
) -> None: ...


def pa_sparse_prefill_opus_fp8_fused(
q_nope: torch.Tensor,
q_rope: torch.Tensor,
q_scale: torch.Tensor,
unified_kv_nope: torch.Tensor,
unified_kv_rope: torch.Tensor,
unified_kv_scale: torch.Tensor,
kv_nope: torch.Tensor,
kv_rope: torch.Tensor,
kv_scale: torch.Tensor,
kv_indices_prefix: torch.Tensor,
kv_indptr_prefix: torch.Tensor,
kv_indices_extend: torch.Tensor,
kv_indptr_extend: torch.Tensor,
attn_sink: torch.Tensor,
softmax_scale: float,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Fused fp8 (v4-layout) sparse prefill attention (no bf16 KV scratch).

QK-nope runs in fp8 MFMA with software per-64-tile scale; QK-rope in bf16;
PV in bf16 with on-chip V dequant. H must be a multiple of 16.
"""
gfx = get_gfx_runtime()
if gfx != "gfx950":
raise RuntimeError(
f"pa_sparse_prefill_opus_fp8_fused requires gfx950, got {gfx}"
)
if q_nope.dtype != torch.float8_e4m3fn:
raise RuntimeError(f"q_nope must be fp8 e4m3fn, got {q_nope.dtype}")
n, h = q_nope.size(0), q_nope.size(1)
if h % 16 != 0:
raise RuntimeError(f"H must be a multiple of 16, got {h}")
if out is None:
out = torch.empty((n, h, _FP8_FULL), dtype=torch.bfloat16, device=q_nope.device)
pa_sparse_prefill_opus_fp8_fused_fwd(
q_nope,
q_rope,
q_scale,
unified_kv_nope,
unified_kv_rope,
unified_kv_scale,
kv_nope,
kv_rope,
kv_scale,
kv_indices_prefix,
kv_indptr_prefix,
kv_indices_extend,
kv_indptr_extend,
attn_sink,
out,
float(softmax_scale),
)
return out


__all__ = [
"pa_sparse_prefill_opus_fwd",
"pa_sparse_prefill_opus",
"pa_sparse_prefill_opus_fp8_fwd",
"pa_sparse_prefill_opus_fp8",
"pa_sparse_prefill_opus_fp8_fused_fwd",
"pa_sparse_prefill_opus_fp8_fused",
]
Loading
Loading