Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 101 additions & 23 deletions aiter/aot/flydsl/chunk_gdn_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

Reads the offline-tuned BV lookup table
``aiter/ops/flydsl/chunk_gdn_h_tuned.csv`` (the same file consumed at
runtime by ``_lookup_tuned_bv`` in ``linear_attention_prefill_kernels``),
extracts every unique compile-time configuration, and pre-compiles it
into the FlyDSL disk cache so that the first inference call does not pay
the JIT cost.
runtime by ``_lookup_csv_bv`` / ``_heuristic_bv`` in
``linear_attention_prefill_kernels``), extracts every unique compile-time
configuration, and pre-compiles it into the FlyDSL disk cache so that the
first inference call does not pay the JIT cost.

Each compile job derived from a csv row is compiled twice -- once with
``STATE_DTYPE_BF16=False`` (legacy f32-state runtime path) and once with ``STATE_DTYPE_BF16=True``
Expand Down Expand Up @@ -59,6 +59,22 @@
)
from aiter.ops.flydsl.kernels.chunk_gated_delta_h import compile_chunk_gated_delta_h

# Mirror of ``linear_attention_prefill_kernels._BV_CANDIDATES`` /
# ``_legal_bv_candidates``. We intentionally do NOT import them from that
# module: importing the runtime host wrapper pulls in ``triton`` and the
# Triton ``gated_delta_rule`` package, whose ``__init__`` runs autotune
# config probing that calls into the Triton driver. In an AOT/CI host with
# no GPU visible (``HIP_VISIBLE_DEVICES=``) that probe raises
# ``RuntimeError: 0 active drivers``. The BV candidate set is a tiny,
# stable constant, so we duplicate the pure logic here to keep the AOT
# entrypoint free of any torch/triton import. Keep in sync with the
# runtime definition.
_BV_CANDIDATES = [16, 32, 64]


def _legal_bv_candidates(V: int) -> list[int]:
    """Return the entries of ``_BV_CANDIDATES`` that are legal BVs for ``V``.

    A candidate ``c`` is legal when it does not exceed ``V`` and divides
    ``V`` evenly. Results keep the order of ``_BV_CANDIDATES``; the list is
    empty when no candidate qualifies (for example ``V`` smaller than the
    smallest candidate, or ``V <= 0``).
    """
    # ``c <= V`` is not redundant: without it ``V == 0`` would accept every
    # candidate, because ``0 % c == 0`` for all ``c``.
    return [c for c in _BV_CANDIDATES if c <= V and V % c == 0]

# Default tuned table lives next to the kernel host wrapper.
_DEFAULT_CSV = (
Path(__file__).resolve().parents[2] / "ops" / "flydsl" / "chunk_gdn_h_tuned.csv"
Expand Down Expand Up @@ -94,12 +110,44 @@ def _parse_bool(s: str) -> bool:
def parse_csv(csv_path: str) -> list[dict[str, Any]]:
"""Parse the chunk_gdn_h tuned csv and return unique compile jobs.

Each row already carries every compile-time switch the kernel cares
about (K/V/BT/H/Hg/use_g/use_gk/use_h0/store_fs/save_vn/is_varlen/
wu_contig) plus the offline-tuned ``BV``. We only keep the fields
that actually influence MLIR compilation; ``T_flat``/``N`` and
``duration`` are dropped (they affect the host launch grid, not the
compiled artifact).
Each row carries a *shape family* + compile-time switches
(K/V/BT/H/Hg/use_g/use_gk/use_h0/store_fs/save_vn/is_varlen/wu_contig).
``T_flat``/``N``/``duration`` are dropped (they affect the host launch
grid, not the compiled artifact).

BV handling (runtime-aligned, see ``linear_attention_prefill_kernels``):
Runtime BV is chosen by the rule-based ``_heuristic_bv``, NOT by the
csv ``BV`` column -- so a single tuned-csv row can resolve to
*several* different BVs at runtime depending on ``(T_flat, N,
min_seqlen)`` (e.g. the H32/Hg16 carve-outs emit BV in {16, 32, 64}).
Pre-compiling only the csv's ``BV`` value would leave the other
rule-reachable BVs to pay a first-call JIT. To keep AOT coverage a
superset of every BV the runtime rule can pick, we ignore the csv
``BV`` column for selection and instead fan each shape family out
over **all legal BVs** for its ``V`` (``_legal_bv_candidates`` --
the same helper the runtime selector uses). This is the chunk-gdn-h
analogue of how the MoE/GEMM AOT paths cover their full runtime
candidate space.

is_varlen / store_fs handling:
``IS_VARLEN`` (cu_seqlens read path) and ``STORE_FINAL_STATE``
(final-state write-back) are compile-time flags; each True / False
choice is a distinct compiled artifact. A tuned-csv row only pins
one combination of them, but at runtime a given shape family can be
called either batched or varlen, with or without final-state output
(e.g. the H32/Hg16 carve-out shapes are exercised varlen +
output_final_state=True, while the csv rows for that family are
batched + store_fs=False). To make AOT coverage independent of those
runtime choices (and decoupled from the runtime rule's family list),
we pre-compile BOTH ``is_varlen`` AND BOTH ``store_fs`` variants for
EVERY shape family.

g_log2_scaled handling:
``G_IS_LOG2_SCALED`` is also a compile-time flag (see
``_compile_chunk_gdn_h_to_cache``). The runtime wrapper defaults to
``use_exp2=True`` -> ``g_log2_scaled=True``, which is the path the
end-to-end prefill and the tests take, so AOT pins it to ``True``.
The (rarer) ``use_exp2=False`` path is left to a first-call JIT.
"""
jobs: list[dict[str, Any]] = []
seen: set[tuple] = set()
Expand All @@ -112,36 +160,39 @@ def parse_csv(csv_path: str) -> list[dict[str, Any]]:
continue

try:
bv = int(row["BV"])
k = int(row["K"])
v = int(row["V"])
except (KeyError, TypeError, ValueError) as e:
print(f" [WARN] malformed row in {csv_path}: {e}")
continue

if v % bv != 0 or bv > v:
# Fan out over every BV the runtime rule could select for this V,
# rather than trusting the (now advisory) csv ``BV`` column.
legal_bvs = _legal_bv_candidates(v)
if not legal_bvs:
print(
f" [WARN] BV={bv} does not divide V={v}, skipping row "
f" [WARN] no legal BV candidate for V={v}, skipping row "
f"in {csv_path}"
)
continue

try:
job = {
base_job = {
"dtype": dtype,
"arch": row.get("arch") or CHUNK_GDN_H_AOT_ARCH_DEFAULT,
"K": k,
"V": v,
"BT": int(row.get("BT") or 64),
"BV": bv,
"H": int(row["H"]),
"Hg": int(row["Hg"]),
"use_g": _parse_bool(row.get("use_g") or "True"),
"use_gk": _parse_bool(row.get("use_gk") or "False"),
"use_h0": _parse_bool(row.get("use_h0") or "True"),
"store_fs": _parse_bool(row.get("store_fs") or "False"),
"save_vn": _parse_bool(row.get("save_vn") or "True"),
"is_varlen": _parse_bool(row.get("is_varlen") or "False"),
# ``is_varlen`` and ``store_fs`` are set per-variant in
# the fan-out loop below (both False and True are
# pre-compiled for each), so their csv columns are not
# read here.
"wu_contig": _parse_bool(row.get("wu_contig") or "True"),
# state dtype is not tracked in the tuned csv yet; default
# f32 here, then main() unconditionally fans out into a
Expand All @@ -152,11 +203,25 @@ def parse_csv(csv_path: str) -> list[dict[str, Any]]:
print(f" [WARN] malformed row in {csv_path}: {e}")
continue

key = job_identity(job)
if key in seen:
continue
seen.add(key)
jobs.append(job)
# Cover both compiled IS_VARLEN and STORE_FINAL_STATE paths for
# every family, so AOT coverage is independent of whether the
# caller runs batched/varlen and whether it requests the final
# state -- both are compile-time flags that produce distinct
# artifacts, and a single tuned-csv row only pins one combination.
for is_varlen in (False, True):
for store_fs in (False, True):
for bv in legal_bvs:
job = dict(
base_job,
BV=bv,
is_varlen=is_varlen,
store_fs=store_fs,
)
key = job_identity(job)
if key in seen:
continue
seen.add(key)
jobs.append(job)

return jobs

Expand Down Expand Up @@ -218,7 +283,12 @@ def _compile_chunk_gdn_h_to_cache(
v = torch.empty((B, H, T_flat, V), device=dev, dtype=torch_dtype)
w = torch.empty((B, H, T_flat, K), device=dev, dtype=torch_dtype)
v_new = torch.empty((B, H, T_flat, V), device=dev, dtype=torch_dtype)
g = torch.empty((B * T_flat, H), device=dev, dtype=torch.float32)
# ``g`` is head-major ``[B, H, T_flat]`` (offset = i_h * T_flat + (bos+row),
# stride=1) to match the K5 kernel's flat addressing in
# ``kernels/chunk_gated_delta_h.py`` (and Triton VK / HIP K5). ``gk`` stays
# token-major ``[B*T_flat, H, K]`` per its own per-K decay addressing
# (offset = (bos+row)*H*K + i_h*K + k).
g = torch.empty((B, H, T_flat), device=dev, dtype=torch.float32)
gk = torch.empty((B * T_flat, H, K), device=dev, dtype=torch.float32)
h = torch.empty((B, max(T_flat // BT, 1), H, V, K), device=dev, dtype=torch_dtype)
h0 = torch.empty((N, H, V, K), device=dev, dtype=state_dtype)
Expand Down Expand Up @@ -246,6 +316,14 @@ def _compile_chunk_gdn_h_to_cache(
IS_VARLEN=is_varlen,
WU_CONTIGUOUS=wu_contig,
STATE_DTYPE_BF16=state_bf16,
# Match the runtime default path: the K5 wrapper defaults to
# ``use_exp2=True`` (and the end-to-end prefill passes the Triton K1
# ``use_exp2`` through, which is True by default), so the runtime
# compile key carries ``G_IS_LOG2_SCALED=True``. Pre-compiling with
# the kernel default (False) would never hit that key -- the whole
# AOT cache would miss on the default path. We pin True here to cover
# it; the (rarer) ``use_exp2=False`` path still JITs on first call.
G_IS_LOG2_SCALED=True,
)

grid_v = (V + BV - 1) // BV
Expand Down
Loading
Loading