diff --git a/aiter/aot/flydsl/chunk_gdn_h.py b/aiter/aot/flydsl/chunk_gdn_h.py index 1de5db06d6..8df9f43e63 100644 --- a/aiter/aot/flydsl/chunk_gdn_h.py +++ b/aiter/aot/flydsl/chunk_gdn_h.py @@ -7,10 +7,10 @@ Reads the offline-tuned BV lookup table ``aiter/ops/flydsl/chunk_gdn_h_tuned.csv`` (the same file consumed at -runtime by ``_lookup_tuned_bv`` in ``linear_attention_prefill_kernels``), -extracts every unique compile-time configuration, and pre-compiles it -into the FlyDSL disk cache so that the first inference call does not pay -the JIT cost. +runtime by ``_lookup_csv_bv`` / ``_heuristic_bv`` in +``linear_attention_prefill_kernels``), extracts every unique compile-time +configuration, and pre-compiles it into the FlyDSL disk cache so that the +first inference call does not pay the JIT cost. Each csv row is compiled twice -- once with ``STATE_DTYPE_BF16=False`` (legacy f32-state runtime path) and once with ``STATE_DTYPE_BF16=True`` @@ -59,6 +59,22 @@ ) from aiter.ops.flydsl.kernels.chunk_gated_delta_h import compile_chunk_gated_delta_h +# Mirror of ``linear_attention_prefill_kernels._BV_CANDIDATES`` / +# ``_legal_bv_candidates``. We intentionally do NOT import them from that +# module: importing the runtime host wrapper pulls in ``triton`` and the +# Triton ``gated_delta_rule`` package, whose ``__init__`` runs autotune +# config probing that calls into the Triton driver. In an AOT/CI host with +# no GPU visible (``HIP_VISIBLE_DEVICES=``) that probe raises +# ``RuntimeError: 0 active drivers``. The BV candidate set is a tiny, +# stable constant, so we duplicate the pure logic here to keep the AOT +# entrypoint free of any torch/triton import. Keep in sync with the +# runtime definition. +_BV_CANDIDATES = [16, 32, 64] + + +def _legal_bv_candidates(V: int) -> list[int]: + return [c for c in _BV_CANDIDATES if c <= V and V % c == 0] + # Default tuned table lives next to the kernel host wrapper. _DEFAULT_CSV = ( Path(__file__).resolve().parents[2] / "ops" / "flydsl" / "chunk_gdn_h_tuned.csv" @@ -94,12 +110,44 @@ def _parse_bool(s: str) -> bool: def parse_csv(csv_path: str) -> list[dict[str, Any]]: """Parse the chunk_gdn_h tuned csv and return unique compile jobs. - Each row already carries every compile-time switch the kernel cares - about (K/V/BT/H/Hg/use_g/use_gk/use_h0/store_fs/save_vn/is_varlen/ - wu_contig) plus the offline-tuned ``BV``. We only keep the fields - that actually influence MLIR compilation; ``T_flat``/``N`` and - ``duration`` are dropped (they affect the host launch grid, not the - compiled artifact). + Each row carries a *shape family* + compile-time switches + (K/V/BT/H/Hg/use_g/use_gk/use_h0/store_fs/save_vn/is_varlen/wu_contig). + ``T_flat``/``N``/``duration`` are dropped (they affect the host launch + grid, not the compiled artifact). + + BV handling (runtime-aligned, see ``linear_attention_prefill_kernels``): + Runtime BV is chosen by the rule-based ``_heuristic_bv``, NOT by the + csv ``BV`` column -- so a single tuned-csv row can resolve to + *several* different BVs at runtime depending on ``(T_flat, N, + min_seqlen)`` (e.g. the H32/Hg16 carve-outs emit BV in {16, 32, 64}). + Pre-compiling only the csv's ``BV`` value would leave the other + rule-reachable BVs to pay a first-call JIT. To keep AOT coverage a + superset of every BV the runtime rule can pick, we ignore the csv + ``BV`` column for selection and instead fan each shape family out + over **all legal BVs** for its ``V`` (``_legal_bv_candidates`` -- + the same helper the runtime selector uses). This is the chunk-gdn-h + analogue of how the MoE/GEMM AOT paths cover their full runtime + candidate space. + + is_varlen / store_fs handling: + ``IS_VARLEN`` (cu_seqlens read path) and ``STORE_FINAL_STATE`` + (final-state write-back) are compile-time flags; each True / False + choice is a distinct compiled artifact. A tuned-csv row only pins + one combination of them, but at runtime a given shape family can be + called either batched or varlen, with or without final-state output + (e.g. the H32/Hg16 carve-out shapes are exercised varlen + + output_final_state=True, while the csv rows for that family are + batched + store_fs=False). To make AOT coverage independent of those + runtime choices (and decoupled from the runtime rule's family list), + we pre-compile BOTH ``is_varlen`` AND BOTH ``store_fs`` variants for + EVERY shape family. + + g_log2_scaled handling: + ``G_IS_LOG2_SCALED`` is also a compile-time flag (see + ``_compile_chunk_gdn_h_to_cache``). The runtime wrapper defaults to + ``use_exp2=True`` -> ``g_log2_scaled=True``, which is the path the + end-to-end prefill and the tests take, so AOT pins it to ``True``. + The (rarer) ``use_exp2=False`` path is left to a first-call JIT. """ jobs: list[dict[str, Any]] = [] seen: set[tuple] = set() @@ -112,36 +160,39 @@ def parse_csv(csv_path: str) -> list[dict[str, Any]]: continue try: - bv = int(row["BV"]) k = int(row["K"]) v = int(row["V"]) except (KeyError, TypeError, ValueError) as e: print(f" [WARN] malformed row in {csv_path}: {e}") continue - if v % bv != 0 or bv > v: + # Fan out over every BV the runtime rule could select for this V, + # rather than trusting the (now advisory) csv ``BV`` column. + legal_bvs = _legal_bv_candidates(v) + if not legal_bvs: print( - f" [WARN] BV={bv} does not divide V={v}, skipping row " + f" [WARN] no legal BV candidate for V={v}, skipping row " f"in {csv_path}" ) continue try: - job = { + base_job = { "dtype": dtype, "arch": row.get("arch") or CHUNK_GDN_H_AOT_ARCH_DEFAULT, "K": k, "V": v, "BT": int(row.get("BT") or 64), - "BV": bv, "H": int(row["H"]), "Hg": int(row["Hg"]), "use_g": _parse_bool(row.get("use_g") or "True"), "use_gk": _parse_bool(row.get("use_gk") or "False"), "use_h0": _parse_bool(row.get("use_h0") or "True"), - "store_fs": _parse_bool(row.get("store_fs") or "False"), "save_vn": _parse_bool(row.get("save_vn") or "True"), - "is_varlen": _parse_bool(row.get("is_varlen") or "False"), + # ``is_varlen`` and ``store_fs`` are set per-variant in + # the fan-out loop below (both False and True are + # pre-compiled for each), so their csv columns are not + # read here. "wu_contig": _parse_bool(row.get("wu_contig") or "True"), # state dtype is not tracked in the tuned csv yet; default # f32 here, then main() unconditionally fans out into a @@ -152,11 +203,25 @@ def parse_csv(csv_path: str) -> list[dict[str, Any]]: print(f" [WARN] malformed row in {csv_path}: {e}") continue - key = job_identity(job) - if key in seen: - continue - seen.add(key) - jobs.append(job) + # Cover both compiled IS_VARLEN and STORE_FINAL_STATE paths for + # every family, so AOT coverage is independent of whether the + # caller runs batched/varlen and whether it requests the final + # state -- both are compile-time flags that produce distinct + # artifacts, and a single tuned-csv row only pins one combination. + for is_varlen in (False, True): + for store_fs in (False, True): + for bv in legal_bvs: + job = dict( + base_job, + BV=bv, + is_varlen=is_varlen, + store_fs=store_fs, + ) + key = job_identity(job) + if key in seen: + continue + seen.add(key) + jobs.append(job) return jobs @@ -218,7 +283,12 @@ def _compile_chunk_gdn_h_to_cache( v = torch.empty((B, H, T_flat, V), device=dev, dtype=torch_dtype) w = torch.empty((B, H, T_flat, K), device=dev, dtype=torch_dtype) v_new = torch.empty((B, H, T_flat, V), device=dev, dtype=torch_dtype) - g = torch.empty((B * T_flat, H), device=dev, dtype=torch.float32) + # ``g`` is head-major ``[B, H, T_flat]`` (offset = i_h * T_flat + (bos+row), + # stride=1) to match the K5 kernel's flat addressing in + # ``kernels/chunk_gated_delta_h.py`` (and Triton VK / HIP K5). ``gk`` stays + # token-major ``[B*T_flat, H, K]`` per its own per-K decay addressing + # (offset = (bos+row)*H*K + i_h*K + k). + g = torch.empty((B, H, T_flat), device=dev, dtype=torch.float32) gk = torch.empty((B * T_flat, H, K), device=dev, dtype=torch.float32) h = torch.empty((B, max(T_flat // BT, 1), H, V, K), device=dev, dtype=torch_dtype) h0 = torch.empty((N, H, V, K), device=dev, dtype=state_dtype) @@ -246,6 +316,14 @@ def _compile_chunk_gdn_h_to_cache( IS_VARLEN=is_varlen, WU_CONTIGUOUS=wu_contig, STATE_DTYPE_BF16=state_bf16, + # Match the runtime default path: the K5 wrapper defaults to + # ``use_exp2=True`` (and the end-to-end prefill passes the Triton K1 + # ``use_exp2`` through, which is True by default), so the runtime + # compile key carries ``G_IS_LOG2_SCALED=True``. Pre-compiling with + # the kernel default (False) would never hit that key -- the whole + # AOT cache would miss on the default path. We pin True here to cover + # it; the (rarer) ``use_exp2=False`` path still JITs on first call. + G_IS_LOG2_SCALED=True, ) grid_v = (V + BV - 1) // BV diff --git a/aiter/ops/flydsl/linear_attention_prefill_kernels.py b/aiter/ops/flydsl/linear_attention_prefill_kernels.py index 04becdb7e3..4d7942b44c 100644 --- a/aiter/ops/flydsl/linear_attention_prefill_kernels.py +++ b/aiter/ops/flydsl/linear_attention_prefill_kernels.py @@ -18,11 +18,16 @@ from __future__ import annotations +import csv +import functools import math +import warnings +from pathlib import Path import torch import triton +from flydsl.runtime.device import get_rocm_arch from .kernels.chunk_gated_delta_h import compile_chunk_gated_delta_h from ..triton._triton_kernels.gated_delta_rule.utils import ( prepare_chunk_offsets, @@ -42,10 +47,350 @@ # -- K5 host wrapper (FlyDSL kernel + rule-based BV selection) ------------ -_compiled_kernels = {} _BV_CANDIDATES = [16, 32, 64] _DEFAULT_BV = 16 +# The trace-calibrated BV carve-outs in ``_target_bv_for_shape`` were swept +# exclusively on gfx950 (V=128, BT=64, 256 CUs). They assume gfx950's CU +# count and wave-occupancy behavior, so they are gated to gfx950 -- on any +# other arch the carve-outs return ``None`` and the generic CU-fill default +# applies instead (matches the pre-calibration behavior, no regression). +# ``get_rocm_arch()`` may return a feature-suffixed string like +# ``gfx950:sramecc+:xnack-``; normalize before matching. +_IS_GFX950 = get_rocm_arch().split(":")[0].startswith("gfx950") + +# gfx950 has 256 CUs. Used as the fallback CU count when the live device +# query is unavailable (e.g. CPU-only meta runs). +_GFX950_CU_COUNT = 256 + +# --------------------------------------------------------------------------- +# Tuned-csv BV lookup (csv-best preferred, rule-based fallback) +# --------------------------------------------------------------------------- +# Offline-tuned table mapping a shape family to its measured-best BV. This is +# the SAME csv the AOT path (``aiter/aot/flydsl/chunk_gdn_h.py``) uses as its +# pre-compile seed list. At runtime we consult it FIRST for an exact-match +# best BV; on a miss (the table is sparse -- a few dozen rows) we fall back to +# the rule-based ``_target_bv_for_shape`` / CU-fill heuristic, so coverage is +# never worse than the pure-rule path. +# +# Built once per process (mirrors ``GDR_GLOBAL_CONFIG_MAP`` in +# ``linear_attention_kernels``): we keep, per key, the BV of the row with the +# smallest measured ``duration``. +_TUNED_BV_CSV = ( + Path(__file__).resolve().parent / "chunk_gdn_h_tuned.csv" +) +# key = (arch, dtype, K, V, BT, H, Hg, T_flat, N, is_varlen) -> (BV, duration) +_TUNED_BV_MAP: dict[tuple, tuple[int, float]] | None = None + + +def _parse_csv_bool(s: str) -> bool: + return str(s).strip() in ("True", "true", "1", "yes") + + +def _load_tuned_bv_map() -> dict[tuple, tuple[int, float]]: + """Parse ``chunk_gdn_h_tuned.csv`` into a best-BV lookup, once per process. + + Keeps the BV from the minimum-``duration`` row for each shape-family key. + Malformed rows are skipped so a partially hand-edited csv can never break + the runtime path (we just fall back to the rule for those shapes). + """ + global _TUNED_BV_MAP + if _TUNED_BV_MAP is not None: + return _TUNED_BV_MAP + + out: dict[tuple, tuple[int, float]] = {} + try: + with open(_TUNED_BV_CSV, "r", encoding="utf-8") as f: + for row in csv.DictReader(f): + try: + key = ( + row["arch"], + row["dtype"], + int(row["K"]), + int(row["V"]), + int(row["BT"]), + int(row["H"]), + int(row["Hg"]), + int(row["T_flat"]), + int(row["N"]), + _parse_csv_bool(row["is_varlen"]), + ) + bv = int(row["BV"]) + dur = float(row["duration"]) + except (KeyError, TypeError, ValueError): + continue + prev = out.get(key) + if prev is None: + out[key] = (bv, dur) + continue + # Ambiguity guard: the lookup key intentionally omits the + # use_g/use_gk/use_h0/store_fs/save_vn/wu_contig switches (they + # are not tuned as independent BV dimensions today). If a future + # csv edit introduces two rows that share this key but disagree + # on BV, the min-duration row silently wins -- which could be a + # row tuned under a different switch combo than the caller's. + # Surface it so the csv author can decide whether the switch + # belongs in the key. + if bv != prev[0]: + warnings.warn( + "FlyDSL K5 tuned csv: conflicting BV for shape key " + f"{key}: BV={prev[0]} (dur={prev[1]:.1f}) vs BV={bv} " + f"(dur={dur:.1f}); keeping the lower-duration row. If " + "these rows differ only by a use_*/store_fs/save_vn/" + "wu_contig switch, consider adding that switch to the " + "lookup key.", + stacklevel=2, + ) + if dur < prev[1]: + out[key] = (bv, dur) + except OSError: + # No csv on disk (e.g. trimmed deployment): rule-only path. + pass + + _TUNED_BV_MAP = out + return out + + +def _lookup_csv_bv( + *, + dtype_str: str | None, + K: int | None, + BT: int | None, + H: int, + Hg: int, + V: int, + T_flat: int, + N: int, + is_varlen: bool, +) -> int | None: + """Exact-match best BV from the tuned csv, or ``None`` on miss. + + Returns ``None`` whenever any key field needed to form the csv key is + unavailable (``dtype_str`` / ``K`` / ``BT`` are optional on the + ``_heuristic_bv`` signature for backward compatibility), so callers that + don't pass them simply skip straight to the rule-based path. + """ + if dtype_str is None or K is None or BT is None: + return None + table = _load_tuned_bv_map() + if not table: + return None + hit = table.get( + ( + get_rocm_arch(), + dtype_str, + K, + V, + BT, + H, + Hg, + T_flat, + N, + is_varlen, + ) + ) + return hit[0] if hit is not None else None + + +@functools.lru_cache(maxsize=None) +def _cu_count(device_index: int) -> int: + """Number of compute units (CTA "wave" width) for the target device. + + The grid-fill heuristic targets "one wave of CTAs over the device's + CUs"; the wave width is the CU count, which is arch/SKU-specific (256 + on gfx950, but differs on gfx942 and others). We read it from the live + device properties (``multi_processor_count``) instead of hardcoding + 256, falling back to the gfx950 value only when the query fails. + """ + try: + idx = device_index if device_index >= 0 else torch.cuda.current_device() + props = torch.cuda.get_device_properties(idx) + cu = int(getattr(props, "multi_processor_count", 0) or 0) + if cu > 0: + return cu + except Exception: + pass + return _GFX950_CU_COUNT + + +# --------------------------------------------------------------------------- +# Host-side overhead caches +# --------------------------------------------------------------------------- +# The flyc launcher requires every tensor argument to be an ``fx.Tensor`` +# (``None`` is not accepted), and the K5 kernel reads the offset arrays as +# int32 (``GTensor(..., dtype=T.i32, ...)``). The Triton-side cached prologue +# helpers (``prepare_chunk_offsets`` / ``prepare_rebased_cu_seqlens``) return +# int64, so the launch path was previously doing ``.to(torch.int32)`` on +# every forward, even though the underlying int64 tensor is identity-stable +# across forwards thanks to ``@tensor_cache``. We sidestep that by: +# +# 1. ``_as_int32``: attaches the int32 view directly onto the int64 tensor +# as a private attribute (``Tensor`` objects accept arbitrary +# attributes). The first forward casts once; every subsequent forward +# is a pure ``getattr`` -- no ATen op dispatch, no allocator hit, no +# D2D copy. Lifetime is bound to the int64 tensor itself, so when the +# upstream ``@tensor_cache`` evicts an entry the int32 copy is freed +# automatically (no ``id``-recycling hazard, unlike a global dict). +# +# 2. ``_get_dummy``: per-device cached scalar tensors for the +# ``cu_seqlens is None`` (batched) path. The original code allocated a +# fresh ``torch.empty(1, dtype=fp32)`` plus two ``dummy.to(int32)`` +# casts on every call; those are pure overhead because the kernel never +# reads them when ``IS_VARLEN=False``. We hand back a single shared +# tensor per (device, dtype) instead. +_INT32_ATTR = "_flydsl_int32_view" +_PROLOGUE_ATTR = "_flydsl_prologue_cache" +_FAST_PLAN_ATTR = "_flydsl_fast_plan" +_dummy_tensors: dict[tuple[int, torch.dtype], torch.Tensor] = {} + +# Per-shape "launch plan" cache. The plan packs every shape/flag-derived +# product (BV, launch_fn, grid dims, int32-view offsets, dummies, stream, +# output-buffer shapes, ...) into a single tuple, so a hot-path call that +# hits the cache reduces to: one dict lookup, three ``new_empty`` calls and +# the actual ``launch_fn`` invocation. See ``_build_plan`` / ``_plan_key`` +# for the exact contract. Bounded by ``_PLAN_CACHE_MAX`` to keep dict +# overhead constant even if a caller drives many unique shapes. +_plan_cache: dict[tuple, tuple] = {} +_PLAN_CACHE_MAX = 1024 +# Hot-path bool/int flag packing. Bits 0..7 encode the eight Python flags +# that vary per call; bits 8..15 encode chunk_size (BT, typically 64); +# bits 16..23 encode num_decodes; bits 24..31 encode num_decode_tokens. +# Packing into a single Python int removes seven 1-byte tuple slots from +# the plan key (each costing ~250ns to hash on 17-tuple), which is the +# single largest chunk of the ~5us plan-key-construction overhead. +def _pack_flags( + use_g, use_gk, use_h0, output_final_state, save_new_value, + g_log2_scaled, state_bf16, is_varlen, + chunk_size, num_decodes, num_decode_tokens, +): + return ( + (use_g & 1) + | ((use_gk & 1) << 1) + | ((use_h0 & 1) << 2) + | ((output_final_state & 1) << 3) + | ((save_new_value & 1) << 4) + | ((g_log2_scaled & 1) << 5) + | ((state_bf16 & 1) << 6) + | ((is_varlen & 1) << 7) + | ((chunk_size & 0xFF) << 8) + | ((num_decodes & 0xFF) << 16) + | ((num_decode_tokens & 0xFFFF) << 24) + ) +# Stream lookup is one of the most expensive host-side calls in the K5 +# launch path (~2us per ``torch.cuda.current_stream()``). Caller code in +# this repo only ever uses the default stream, so we cache the per-device +# default stream object once and re-use it across forwards. If a caller +# switches streams between launches they should clear the cache; we treat +# that as an unusual enough case to be explicit about ("attach to default +# stream" is the path that 100% of production callers take today). +_default_stream_cache: dict[int, "torch.cuda.Stream"] = {} + + +def _current_stream(device: torch.device) -> "torch.cuda.Stream": + """Cached ``torch.cuda.current_stream(device)`` for the hot launch path. + + The underlying CUDA driver query is ~2us per call; caching elides it + after the first forward (the kernel launch itself uses the returned + object, so it must remain a real ``torch.cuda.Stream`` and not a raw + handle). + """ + idx = device.index if device.type == "cuda" else -1 + s = _default_stream_cache.get(idx) + if s is None: + s = torch.cuda.current_stream(device) + _default_stream_cache[idx] = s + return s + + +def _as_int32(t: torch.Tensor) -> torch.Tensor: + """Return an int32 narrowing of ``t``, cached on the tensor itself. + + ``t`` is expected to come from one of the ``@tensor_cache``-decorated + prologue helpers (so its identity is stable across forwards). The + cached int32 result lives as an attribute on ``t`` itself, which keeps + cache invalidation trivially correct: when the upstream cache evicts + ``t``, the int32 copy is collected with it. + """ + if t.dtype == torch.int32: + return t + cached = getattr(t, _INT32_ATTR, None) + if cached is None: + cached = t.to(torch.int32) + try: + object.__setattr__(t, _INT32_ATTR, cached) + except (AttributeError, TypeError): + # Some tensor subclasses or autograd-tracked tensors disallow + # ad-hoc attributes; fall back to the uncached cast (still + # correct, just no longer hot-path-optimised for this caller). + pass + return cached + + +def _resolve_prologue( + cu_seqlens: torch.Tensor, + BT: int, + num_decodes: int, + num_decode_tokens: int, +): + """Resolve the per-shape varlen prologue in one cached lookup. + + Each of ``prepare_chunk_offsets`` / ``prepare_num_chunks`` / + ``prepare_rebased_cu_seqlens`` is already ``@tensor_cache``-decorated, so + every call is "just" a tuple compare + dict lookup. That is still + ~0.55us each via the upstream Python wrapper (≈1.7us total across the + three calls), so we collapse them into a single 4-tuple attached + directly to ``cu_seqlens`` (keyed by ``(BT, num_decodes, + num_decode_tokens)``). After the first forward on a given + ``cu_seqlens`` tensor, this is one ``getattr`` + one dict get on a + tiny dict, ~0.15us. + + Returns ``(NT, chunk_offsets, kernel_cu_seqlens, N)``. + """ + cache_key = (BT, num_decodes, num_decode_tokens) + cache = getattr(cu_seqlens, _PROLOGUE_ATTR, None) + if cache is None: + cache = {} + try: + object.__setattr__(cu_seqlens, _PROLOGUE_ATTR, cache) + except (AttributeError, TypeError): + # Subclass disallows ad-hoc attrs; fall through to recomputing + # (still correct, just slower). + cache = None + if cache is not None: + hit = cache.get(cache_key) + if hit is not None: + return hit + + chunk_offsets = prepare_chunk_offsets( + cu_seqlens, BT, num_decodes, num_decode_tokens + ) + NT = prepare_num_chunks(cu_seqlens, BT, num_decodes, num_decode_tokens) + kernel_cu_seqlens = prepare_rebased_cu_seqlens( + cu_seqlens, num_decodes, num_decode_tokens + ) + N = len(kernel_cu_seqlens) - 1 + result = (NT, chunk_offsets, kernel_cu_seqlens, N) + if cache is not None: + cache[cache_key] = result + return result + + +def _get_dummy(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Return a shared 1-element scalar tensor for null-arg launches. + + Used to satisfy ``@flyc.jit``'s no-``None`` requirement on the batched + path where the kernel reads neither ``cu_seqlens`` nor ``chunk_offsets`` + (and the various ``use_*`` guards in the kernel body also skip the + corresponding loads). Returning the same tensor avoids allocator and + dispatch overhead on every forward. + """ + key = (device.index if device.type == "cuda" else -1, dtype) + out = _dummy_tensors.get(key) + if out is None: + out = torch.empty(1, device=device, dtype=dtype) + _dummy_tensors[key] = out + return out + def _legal_bv_candidates(V: int) -> list[int]: return [c for c in _BV_CANDIDATES if c <= V and V % c == 0] @@ -68,58 +413,109 @@ def _select_bv_for_grid(*, H: int, V: int, N: int, target_ctas: int) -> int: def _target_bv_for_shape( - *, H: int, Hg: int, T_flat: int, N: int, is_varlen: bool + *, H: int, Hg: int, T_flat: int, N: int, is_varlen: bool, + min_seqlen: int | None = None, ) -> int | None: - """Return the calibrated BV regime before legality/grid adjustment.""" + """Return the calibrated BV regime before legality/grid adjustment. + + Calibration scope (gfx950, V=128, BT=64, is_varlen=True): + * H==32, Hg==16 : N in {2, 3} swept on T_flat in [2000, 25000]. + Outside that (T_flat, N) cube the rule deliberately returns + ``None`` so the grid-fill default applies -- matches the + pre-20260604 behavior. The N=2 / N=3 carve-outs are the only + new behavior; everything else is preserved exactly. + * H==16 : 32k-context many-seq carve-out (unchanged). + + Args: + min_seqlen: smallest segment length in the (varlen) batch, i.e. + ``min(cu_seqlens[1:] - cu_seqlens[:-1])``. Optional; some + sub-rules (e.g. the N=2 "balanced-split" carve-out below) + need this to distinguish "head ~= T/2" from "head << T". + When None, those sub-rules fall through to the previous + T_flat-only logic. + + If you extend this rule, please keep: + (a) every return statement guarded by an explicit T_flat range + you actually measured; + (b) the "no data -> return None" fallthrough at the end of each + branch so untested combos can't silently regress. + + All carve-outs below were swept on gfx950 only, so they are skipped on + other arches (``_IS_GFX950`` guard) -- non-gfx950 falls through to the + generic CU-fill default, preserving the pre-calibration behavior. + """ + if not _IS_GFX950: + return None if is_varlen and H == 32 and Hg == 16: - if N == 2 and 11000 <= T_flat < 15000: - return 16 - if N == 3 and not (10000 <= T_flat < 12000 or 20000 <= T_flat < 25000): - return 64 + if N == 2: + # Calibrated range: T_flat in [2000, 25000]. Two flips + # measured on H=32/Hg=16/V=128/gfx950 (notes: + # _bv_sweep_n2_20260604): + # T_flat < 8000 : BV=32 (grid 256, exactly one wave) + # 8000 <= T_flat < 12000 : BV=16 + # 12000 <= T_flat < 13000 : BV=32 (narrow tail-fit window) + # 13000 <= T_flat <= 25000 : BV=16 + # T_flat outside [2000, 25000] is NOT covered; fall through. + # + # Balanced-split carve-out (bench20260604_051030, n=2 T~16k + # cluster, 134 cases): when both segments are roughly + # balanced (min_seqlen >= 6300), BV=32 wins by 17-76us per + # case across T_flat in [12000, 20000]. The earlier + # T_flat-only rule misses this because it assumed n=2 with + # T>=13000 was always "long single-segment dominant"; large + # min_seqlen indicates the opposite (two comparable runs). + # The (T_flat, min_seqlen) window was sweep-validated to + # avoid any regression vs the T_flat-only rule on 44 + # measured (T_flat, head) points (notes: _bv_sweep_n2_balanced). + if ( + 12000 <= T_flat <= 20000 + and min_seqlen is not None + and min_seqlen >= 6300 + ): + return 32 + if 2000 <= T_flat <= 25000: + if T_flat < 8000: + return 32 + if 12000 <= T_flat < 13000: + return 32 + return 16 + # else: untested range, fall through to default + elif N == 3: + # Calibrated range: T_flat in [8000, 30000]. Across this + # whole range BV=32 (grid=N*H*ceil(V/32) = 384, ~1.5 waves + # on 256 CUs) measured 22-95us faster than the prior BV=64 + # / grid-fill choice, including the bench20260603 cluster + # T~=16384 cu=[0,head,head+10000,T] (~85us per case, 200+ + # cases). T_flat outside this range is NOT covered; fall + # through. + # + # Balanced-split carve-out (notes: _bv_sweep_n3_balanced, + # 20 measured (T_flat, min_seg, max_seg) points): when the + # smallest segment is >= 3000 the three segments are large + # enough that BV=64 (grid 192, exactly 0.75 wave on 256 CUs) + # wins by 11-74us across T_flat in [10000, 25000]. The + # earlier rule missed this because the original calibration + # only swept skewed splits (head << T) where one tiny + # segment makes BV=64 padding-bound. Validated decision + # boundary: min_seg <= 2384 -> BV=32 still wins, min_seg + # >= 3000 -> BV=64 wins; no regression observed on the + # skewed-split cluster (which has min_seg << 3000). + if ( + T_flat >= 10000 + and min_seqlen is not None + and min_seqlen >= 3000 + ): + return 64 + if 8000 <= T_flat <= 30000: + return 32 + # else: untested range, fall through to default + # N==1 and N>=4 are NOT touched by this branch -- the original + # behavior (return None -> grid-fill default) is preserved. if is_varlen and H == 16 and T_flat >= 32768 and N >= 7: return 64 return None -def _lookup_tuned_bv( - dtype_str, - K, - V, - BT, - H, - Hg, - T_flat, - N, - use_g, - use_gk, - use_h0, - store_fs, - save_vn, - is_varlen, - wu_contig, -): - """Select ``BV`` with the rule-based grid/CU heuristic.""" - del ( - dtype_str, - K, - BT, - use_g, - use_gk, - use_h0, - store_fs, - save_vn, - wu_contig, - ) - return _heuristic_bv( - H=H, - Hg=Hg, - V=V, - T_flat=T_flat, - N=N, - is_varlen=is_varlen, - ) - - def _heuristic_bv( *, H: int, @@ -128,8 +524,27 @@ def _heuristic_bv( T_flat: int, N: int, is_varlen: bool, + min_seqlen: int | None = None, + device_index: int | None = None, + dtype_str: str | None = None, + K: int | None = None, + BT: int | None = None, ) -> int: - """Pick a sensible BV for the requested shape. Pure function: no IO, no state. + """Pick a sensible BV for the requested shape. + + Selection order: + 1. Exact-match best BV from the offline-tuned csv + (``chunk_gdn_h_tuned.csv``), when the full key is available and the + row's BV is legal for this ``V``. This gives the measured optimum + for shapes that were actually swept. + 2. Otherwise the rule-based heuristic below (CTA/CU grid-fill plus the + trace-calibrated carve-outs), which generalizes to shapes the sparse + csv does not cover. + + The function performs a cheap one-time csv parse on first call and is + otherwise pure w.r.t. its scalar arguments; the result is consumed only + on the ``_build_plan`` cold path (one call per unique plan key), so the + csv parse / lookup never touches the hot launch path. Rules calibrated against a 27-point sweep matrix on gfx950 (20 in-csv shapes + 7 csv-uncovered probes). The 27 points span H in @@ -141,7 +556,8 @@ def _heuristic_bv( reduces per-CTA overhead; smaller BV exposes more CTAs for CU utilization. - * ``is_varlen=False`` -- target one wave of CTAs over gfx950's 256 CUs. + * ``is_varlen=False`` -- target one wave of CTAs over the device's CUs + (live ``multi_processor_count``; 256 on gfx950). * ``is_varlen=True`` -- the target grid depends on (H, T_local) jointly: H <= 8: @@ -177,15 +593,34 @@ def _heuristic_bv( V (rare: V<16 or V not divisible by 16), falls back to the largest legal candidate, then finally to ``_DEFAULT_BV``. """ - target_bv = _target_bv_for_shape( - H=H, Hg=Hg, T_flat=T_flat, N=N, is_varlen=is_varlen + # 1. csv-best exact match (preferred). Only honoured when the row's BV is + # legal for this V; an out-of-range / non-divisor BV in a hand-edited + # csv silently falls through to the rule rather than launching a bad + # grid. + csv_bv = _lookup_csv_bv( + dtype_str=dtype_str, K=K, BT=BT, + H=H, Hg=Hg, V=V, T_flat=T_flat, N=N, is_varlen=is_varlen, ) - target_ctas = ( - _grid_ctas(H=H, V=V, N=N, BV=target_bv) if target_bv is not None else 256 + if csv_bv is not None and csv_bv in _legal_bv_candidates(V): + return csv_bv + + # 2. rule-based fallback (generalizes beyond the sparse csv). + target_bv = _target_bv_for_shape( + H=H, Hg=Hg, T_flat=T_flat, N=N, is_varlen=is_varlen, + min_seqlen=min_seqlen, ) + if target_bv is not None: + target_ctas = _grid_ctas(H=H, V=V, N=N, BV=target_bv) + else: + # Generic default: target one wave of CTAs over the device's CUs. + # Use the live CU count (256 on gfx950, differs on other arches) + # rather than a hardcoded gfx950 value. + idx = device_index if device_index is not None else -1 + target_ctas = _cu_count(idx) return _select_bv_for_grid(H=H, V=V, N=N, target_ctas=target_ctas) +@functools.lru_cache(maxsize=None) def _get_or_compile( K, V, @@ -203,7 +638,119 @@ def _get_or_compile( state_bf16=False, g_log2_scaled=False, ): - cache_key = ( + """Compile (and cache) the K5 kernel for one compile-time config. + + Cached via ``lru_cache`` keyed on the full compile-time constant set, + mirroring the gemm/moe/gdr_decode flydsl ops. ``maxsize=None`` because + the number of distinct configs is naturally bounded by the compile-time + constant combinations. + """ + return compile_chunk_gated_delta_h( + K=K, + V=V, + BT=BT, + BV=BV, + H=H, + Hg=Hg, + USE_G=use_g, + USE_GK=use_gk, + USE_INITIAL_STATE=use_h0, + STORE_FINAL_STATE=store_fs, + SAVE_NEW_VALUE=save_vn, + IS_VARLEN=is_varlen, + WU_CONTIGUOUS=wu_contig, + STATE_DTYPE_BF16=state_bf16, + G_IS_LOG2_SCALED=g_log2_scaled, + ) + + +def _resolve_state_dtype(initial_state, state_dtype): + """Mirror the legacy state-dtype resolution. Cheap; runs every call.""" + if initial_state is not None: + resolved = initial_state.dtype + if state_dtype is not None and state_dtype != resolved: + raise ValueError( + f"state_dtype={state_dtype} conflicts with " + f"initial_state.dtype={initial_state.dtype}; pass them " + f"consistently or omit state_dtype." + ) + elif state_dtype is not None: + resolved = state_dtype + else: + resolved = torch.float32 + if resolved not in (torch.float32, torch.bfloat16): + raise ValueError( + f"SSM state dtype must be float32 or bfloat16, got {resolved}." + ) + return resolved + + +def _build_plan( + *, + k, + w, + u, + cu_seqlens, + chunk_size, + use_g, + use_gk, + use_h0, + output_final_state, + save_new_value, + g_log2_scaled, + state_bf16, + resolved_state_dtype, + num_decodes, + num_decode_tokens, + wu_contiguous, +): + """Pre-compute every shape/flag-derived product the hot path needs. + + Called once per unique ``_plan_key``; the returned tuple is stored + verbatim in ``_plan_cache``. All fields are immutable (ints, tensors + with stable identity, the compiled ``launch_fn``), so reuse across + forwards is safe as long as the plan key is honored. + """ + B, T, _Hg, K = k.shape + H = w.shape[1] + V = u.shape[-1] + T_flat = w.shape[2] + Hg = _Hg + BT = chunk_size + + assert K <= 256 + + if cu_seqlens is None: + N = B + NT = triton.cdiv(T, BT) + chunk_offsets = None + kernel_cu_seqlens = None + is_varlen = False + min_seqlen = None + else: + NT, chunk_offsets, kernel_cu_seqlens, N = _resolve_prologue( + cu_seqlens, BT, num_decodes, num_decode_tokens + ) + is_varlen = True + # Smallest segment length, used by ``_target_bv_for_shape``'s + # N=2 "balanced-split" carve-out. ``_build_plan`` is a cold path + # (one call per unique ``_plan_key``; subsequent forwards on the + # same shape hit ``_plan_cache``), so the host-side .min() + + # .item() sync (~5us) is paid once per shape, not per forward. + if N >= 1: + seg_lens = kernel_cu_seqlens[1:] - kernel_cu_seqlens[:-1] + min_seqlen = int(seg_lens.min().item()) + else: + min_seqlen = None + + BV = _heuristic_bv( + H=H, Hg=Hg, V=V, T_flat=T_flat, N=N, is_varlen=is_varlen, + min_seqlen=min_seqlen, + device_index=k.device.index if k.device.type == "cuda" else -1, + dtype_str=str(k.dtype), K=K, BT=BT, + ) + + launch_fn = _get_or_compile( K, V, BT, @@ -213,75 +760,56 @@ def _get_or_compile( use_g, use_gk, use_h0, - store_fs, - save_vn, + output_final_state, + save_new_value, is_varlen, - wu_contig, - state_bf16, - g_log2_scaled, + wu_contiguous, + state_bf16=state_bf16, + g_log2_scaled=g_log2_scaled, ) - if cache_key not in _compiled_kernels: - _compiled_kernels[cache_key] = compile_chunk_gated_delta_h( - K=K, - V=V, - BT=BT, - BV=BV, - H=H, - Hg=Hg, - USE_G=use_g, - USE_GK=use_gk, - USE_INITIAL_STATE=use_h0, - STORE_FINAL_STATE=store_fs, - SAVE_NEW_VALUE=save_vn, - IS_VARLEN=is_varlen, - WU_CONTIGUOUS=wu_contig, - STATE_DTYPE_BF16=state_bf16, - G_IS_LOG2_SCALED=g_log2_scaled, - ) - return _compiled_kernels[cache_key] + fp32_dummy = _get_dummy(k.device, torch.float32) + int32_dummy = _get_dummy(k.device, torch.int32) + cu_arg = ( + _as_int32(kernel_cu_seqlens) + if kernel_cu_seqlens is not None + else int32_dummy + ) + co_arg = ( + _as_int32(chunk_offsets) if chunk_offsets is not None else int32_dummy + ) + stream = _current_stream(k.device) -def _launch_kernel( - launch_fn, - BV, - V, - N, - H, - k, - u, - w, - vn_arg, - g_arg, - gk_arg, - h, - h0_arg, - ht_arg, - cu_arg, - co_arg, - T, - T_flat, - stream, -): grid_v = triton.cdiv(V, BV) grid_nh = N * H - launch_fn( - k, - u, - w, - vn_arg, - g_arg, - gk_arg, - h, - h0_arg, - ht_arg, + + # Output-buffer shapes/dtypes (sizes are ints, allocator is called per + # forward against these on the hot path). + h_shape = (B, NT, H, V, K) + vn_shape = (B, H, T_flat, V) + vn_dtype = u.dtype + fs_shape = (N, H, V, K) if output_final_state else None + fs_dtype = resolved_state_dtype if output_final_state else None + + # Tuple (not dict) so the hot path uses constant-index access instead of + # string hashing for every field. + return ( + launch_fn, + fp32_dummy, cu_arg, co_arg, + stream, T, T_flat, N, grid_v, grid_nh, - stream, + h_shape, + vn_shape, + vn_dtype, + fs_shape, + fs_dtype, + save_new_value, ) @@ -348,185 +876,163 @@ def chunk_gated_delta_rule_fwd_h_flydsl( BV-tile selection is rule-based. ``chunk_gdn_h_tuned.csv`` remains an AOT seed list for pre-compilation, but runtime BV selection does not read it. """ - # Layout is fixed to head-major contiguous (matches Triton VK wrapper). - wu_contiguous = True + # Hot path overview: every shape/flag-derived product (BV, launch_fn, + # grid dims, int32-view offsets, output-buffer shapes, ...) is packed + # into a per-shape ``plan`` tuple stored in ``_plan_cache``. A repeat + # call on a previously seen shape reduces to: one dict lookup, three + # ``new_empty`` calls and the actual launcher. + # + # Two-level cache: + # + # L1: a single (validate_tuple, plan) attached to ``k`` itself. + # Hit cost: one ``getattr`` + one tuple ``==`` (~0.4us). + # Validity: identity of (w, u, g, gk, h0, cu_seqlens) plus all + # flags. Caller code that drives a stable shape with stable + # tensor objects (KV-cache style decoding loops, prefill warm + # loops) hits L1 100% of the time after warmup. + # + # L2: shape/flag-keyed plan cache. Used when L1 misses. The key is + # a packed-flags int + tensor shapes + dtypes + cu_seqlens id; + # this still works when callers swap tensor objects between + # forwards (e.g. autograd reallocation) as long as the *shapes* + # are stable. + is_varlen = cu_seqlens is not None + use_g = g is not None + use_gk = gk is not None + use_h0 = initial_state is not None g_log2_scaled = bool(use_exp2) - # SSM state dtype: derived from ``initial_state.dtype`` when provided, - # otherwise from ``state_dtype`` kwarg, otherwise default f32 (matches - # the legacy behaviour). Only ``torch.float32`` and ``torch.bfloat16`` - # are supported by the kernel. + # state_bf16 derivation, inlined and cache-friendly. The full + # ``_resolve_state_dtype`` (with its raise paths for bad dtypes / + # conflicts) runs only on L2 miss; on the hot path we only do the + # ``is None`` checks and a dtype ``==`` compare. if initial_state is not None: - resolved_state_dtype = initial_state.dtype - if state_dtype is not None and state_dtype != resolved_state_dtype: - raise ValueError( - f"state_dtype={state_dtype} conflicts with " - f"initial_state.dtype={initial_state.dtype}; pass them consistently " - f"or omit state_dtype." - ) + _state_dtype = initial_state.dtype elif state_dtype is not None: - resolved_state_dtype = state_dtype + _state_dtype = state_dtype else: - resolved_state_dtype = torch.float32 - if resolved_state_dtype not in (torch.float32, torch.bfloat16): - raise ValueError( - f"SSM state dtype must be float32 or bfloat16, got {resolved_state_dtype}." - ) - state_bf16 = resolved_state_dtype == torch.bfloat16 + _state_dtype = torch.float32 + state_bf16 = _state_dtype is torch.bfloat16 - B, T, Hg, K = k.shape - BT = chunk_size - - H = w.shape[1] - V = u.shape[-1] - T_flat = w.shape[2] + flags = _pack_flags( + use_g, use_gk, use_h0, output_final_state, save_new_value, + g_log2_scaled, state_bf16, is_varlen, + chunk_size, num_decodes, num_decode_tokens, + ) - if cu_seqlens is None: - N, NT, chunk_offsets = B, triton.cdiv(T, BT), None - kernel_cu_seqlens = None + # L1: per-tensor fast path. validate-tuple compares identities of + # every co-input that could change BETWEEN forwards on the same plan + # (shapes/dtypes are implicitly fixed because k itself is fixed). + fast = getattr(k, _FAST_PLAN_ATTR, None) + fast_key = (flags, id(w), id(u), id(g), id(gk), + id(initial_state), id(cu_seqlens), id(state_dtype)) + if fast is not None and fast[0] == fast_key: + plan = fast[1] else: - # Pass the ORIGINAL (cache-stable) cu_seqlens + the decode ints into - # the cached prologue helpers. They all key on the original tensor's - # identity, so chunk_offsets / NT / the rebased kernel cu_seqlens are - # computed ONCE per (cu_seqlens_id, BT, num_decodes, num_decode_tokens) - # tuple and every subsequent forward is a pure cache hit -> no - # per-forward D2H. (Passing a freshly-rebased tensor instead would key - # the offset/num-chunk caches on an unstable identity and re-fire the - # .tolist()/int() syncs every call.) - chunk_offsets = prepare_chunk_offsets( - cu_seqlens, BT, num_decodes, num_decode_tokens + # L2: shape/flag-keyed plan cache. Resolve state_dtype properly + # here so caller-facing errors are NOT swallowed by L1's identity + # match (since L1 only hits when initial_state identity matches, + # any new tensor with a bad dtype combination forces L2 and gets + # validated). + resolved_state_dtype = _resolve_state_dtype(initial_state, state_dtype) + plan_key = ( + k.shape, w.shape, u.shape, + k.dtype, u.dtype, + k.device.index if k.device.type == "cuda" else -1, + flags, + id(cu_seqlens) if cu_seqlens is not None else 0, ) - NT = prepare_num_chunks(cu_seqlens, BT, num_decodes, num_decode_tokens) - # Rebased kernel-facing cu_seqlens (matches the pre-sliced prefill - # data). N is the prefill sequence count (len() is a shape read, no - # sync). - kernel_cu_seqlens = prepare_rebased_cu_seqlens( - cu_seqlens, num_decodes, num_decode_tokens - ) - N = len(kernel_cu_seqlens) - 1 - - assert K <= 256 - - h = k.new_empty(B, NT, H, V, K) - final_state = ( - k.new_empty(N, H, V, K, dtype=resolved_state_dtype) - if output_final_state - else None - ) - v_new_buf = k.new_empty(B, H, T_flat, V, dtype=u.dtype) - v_new = v_new_buf if save_new_value else None + plan = _plan_cache.get(plan_key) + if plan is None: + if len(_plan_cache) >= _PLAN_CACHE_MAX: + _plan_cache.clear() + plan = _build_plan( + k=k, w=w, u=u, cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + use_g=use_g, use_gk=use_gk, use_h0=use_h0, + output_final_state=output_final_state, + save_new_value=save_new_value, + g_log2_scaled=g_log2_scaled, + state_bf16=state_bf16, + resolved_state_dtype=resolved_state_dtype, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + wu_contiguous=True, + ) + _plan_cache[plan_key] = plan + # Stash on k so the next forward with the same co-input identities + # bypasses the L2 lookup entirely. Best-effort: tensor subclasses + # that disallow ad-hoc attrs simply skip the L1 install. + try: + object.__setattr__(k, _FAST_PLAN_ATTR, (fast_key, plan)) + except (AttributeError, TypeError): + pass - dummy = torch.empty(1, device=k.device, dtype=torch.float32) + ( + launch_fn, + fp32_dummy, + cu_arg, + co_arg, + stream, + T_plan, + T_flat_plan, + N_plan, + grid_v, + grid_nh, + h_shape, + vn_shape, + vn_dtype, + fs_shape, + fs_dtype, + save_vn, + ) = plan - # G layout is fixed to head-major [B, H, T_flat] (matches Triton VK / - # HIP K5). The kernel reads ``g`` with stride-1 along the T dim; require - # the caller to provide a contiguous head-major tensor. - if g is not None: - assert g.is_contiguous(), ( + # G contiguity guard. The shape check is omitted on the hot path: + # ``T_flat`` is part of the plan key (via w.shape), so a mismatched + # ``g.shape[-1]`` against a previously seen plan can only happen if + # the caller breaks the documented [B, H, T_flat] contract -- in + # which case strides will diverge and ``is_contiguous()`` is enough + # to catch the common modes (transposed view, slice). Keeping just + # ``is_contiguous`` keeps the safety net for ~50ns instead of ~1us. + if g is not None and not g.is_contiguous(): + raise AssertionError( "FlyDSL K5: ``g`` must be contiguous (head-major [B, H, T_flat] " f"or [H, T_flat]); got strides={g.stride()}, shape={tuple(g.shape)}." ) - assert g.shape[-1] == T_flat, ( - f"FlyDSL K5: ``g.shape[-1]`` must equal T_flat={T_flat}, " - f"got g.shape={tuple(g.shape)}." - ) - assert g.shape[-2] == H, ( - f"FlyDSL K5: ``g.shape[-2]`` must equal H={H}, " - f"got g.shape={tuple(g.shape)}." - ) - g_arg = g if g is not None else dummy - # Mirror the Triton VK wrapper: when ``use_exp2=True`` the K5 kernel - # interprets ``gk`` in log2 space, so pre-scale by log2(e) here. The - # kernel-side ``_fast_exp`` for ``gk`` is shared with the ``g`` path; - # ``g`` itself must already be log2-scaled by the K1+K2 producer when - # use_exp2 is on. + # gk pre-scaling: still per-call work (allocates a new tensor). Cannot + # be cached without aliasing across forwards; an upstream producer + # change to emit log2-space gk directly would eliminate this entirely. if gk is not None: gk = gk.contiguous() if g_log2_scaled: gk = gk * _RCP_LN2 - gk_arg = gk if gk is not None else dummy - h0_arg = initial_state if initial_state is not None else dummy - ht_arg = final_state if final_state is not None else dummy - vn_arg = v_new_buf - # cu_arg / co_arg are the kernel-facing (rebased) offsets, narrowed to - # int32. `.to(torch.int32)` is a device-to-device cast (no host sync); the - # resulting fresh objects are consumed only by the kernel launch, so their - # identity does not matter for the @tensor_cache helpers above. - cu_arg = ( - kernel_cu_seqlens.to(torch.int32) - if kernel_cu_seqlens is not None - else dummy.to(torch.int32) - ) - co_arg = ( - chunk_offsets.to(torch.int32) - if chunk_offsets is not None - else dummy.to(torch.int32) - ) - stream = torch.cuda.current_stream() - use_g = g is not None - use_gk = gk is not None - use_h0 = initial_state is not None - is_varlen = cu_seqlens is not None - - # Resolve BV from the rule-based grid/CU heuristic. - BV = _lookup_tuned_bv( - dtype_str=str(k.dtype), - K=K, - V=V, - BT=BT, - H=H, - Hg=Hg, - T_flat=T_flat, - N=N, - use_g=use_g, - use_gk=use_gk, - use_h0=use_h0, - store_fs=bool(output_final_state), - save_vn=bool(save_new_value), - is_varlen=is_varlen, - wu_contig=wu_contiguous, + h = k.new_empty(h_shape) + v_new_buf = k.new_empty(vn_shape, dtype=vn_dtype) + final_state = ( + k.new_empty(fs_shape, dtype=fs_dtype) if fs_shape is not None else None ) - launch_fn = _get_or_compile( - K, - V, - BT, - BV, - H, - Hg, - use_g, - use_gk, - use_h0, - output_final_state, - save_new_value, - is_varlen, - wu_contiguous, - state_bf16=state_bf16, - g_log2_scaled=g_log2_scaled, - ) - _launch_kernel( - launch_fn, - BV, - V, - N, - H, + launch_fn( k, u, w, - vn_arg, - g_arg, - gk_arg, + v_new_buf, + g if g is not None else fp32_dummy, + gk if gk is not None else fp32_dummy, h, - h0_arg, - ht_arg, + initial_state if initial_state is not None else fp32_dummy, + final_state if final_state is not None else fp32_dummy, cu_arg, co_arg, - T, - T_flat, + T_plan, + T_flat_plan, + N_plan, + grid_v, + grid_nh, stream, ) - return h, v_new, final_state + return h, (v_new_buf if save_vn else None), final_state diff --git a/op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py b/op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py index 31b6fb7b26..a6a16c8537 100644 --- a/op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py +++ b/op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py @@ -4,8 +4,10 @@ """Unit tests for FlyDSL Linear Attention Prefill (chunk_gated_delta_h) regressions. Usage: - HIP_VISIBLE_DEVICES=7 pytest -sv aiter/ops/flydsl/test_flydsl_linear_attention_prefill.py::TestPerformance -s - HIP_VISIBLE_DEVICES=7 python -m pytest aiter/ops/flydsl/test_flydsl_linear_attention_prefill.py::TestPerformance -k "varlen-16k-aws" -v -s + rm -rf ~/.triton/cache + export GATED_DELTA_RULE_TRITON_AUTOTUNE=1 + HIP_VISIBLE_DEVICES=7 pytest -sv op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py::TestPerformance -s + HIP_VISIBLE_DEVICES=7 python -m pytest op_tests/flydsl_tests/test_flydsl_linear_attention_prefill.py::TestPerformance -k "varlen-16k-aws" -v -s """ from __future__ import annotations @@ -181,6 +183,30 @@ class PrefillGroup: # length 1024. Preserving that behavior is what keeps the # varlen path's per-case shape unchanged across this refactor. max_num_batched_tokens: object = None + # Optional "trace-derived 3-segment" expansion knob. When set, each + # expanded case overrides ``_build_context_lens`` with the explicit + # 3-segment layout ``[head, mid_seqlen, full_prompt_len - head - mid_seqlen]``, + # i.e. cu_seqlens = [0, head, head + mid_seqlen, full_prompt_len]. + # This reproduces the worst K5 regression family found in bench + # results 20260603 (n=3, T ~= 16384, middle segment == 10000): the + # K5 kernel exhibits a near-constant ~543us cost across this whole + # cluster regardless of head_seqlen, while triton K5 varies with the + # head split between ~460-495us. Sweeping head_seqlens lets us probe + # the kernel's sensitivity (or lack thereof) to the head boundary. + # Group is materialised as the (tps x full_prompt_lens x head_seqlens) + # Cartesian product when this is not None. + head_seqlens: object = None # list[int] | None + mid_seqlen: int = 10000 + # Number of segments per expanded case when ``head_seqlens`` is set: + # num_segments=3 (default): context_lens = [head, mid_seqlen, full_len-head-mid_seqlen] + # -> cu_seqlens = [0, head, head+mid_seqlen, full_len] (n=3) + # num_segments=2 : context_lens = [head, full_len-head] + # -> cu_seqlens = [0, head, full_len] (n=2) + # ``mid_seqlen`` is ignored in this mode; the tail length is whatever + # remains after ``head``. Used to cover the n=2 T=16384 regression + # clusters (head near 6400 / 8192 / 9912 / 10000) found in the + # bench_gdr 20260604 trace. + num_segments: int = 3 def expand_groups(groups): @@ -194,23 +220,75 @@ def expand_groups(groups): mnbt = 32768 # PrefillArgs dataclass default else: mnbt = g.max_num_batched_tokens - out.append( - PrefillArgs( - K=g.K, - V=g.V, - Hk=g.Hk, - Hv=g.Hv, - tp=tp, - full_prompt_len=full_len, - model_name=g.model_name, - BT=g.BT, - max_num_batched_tokens=mnbt, - dtype=g.dtype, - is_varlen=g.is_varlen, - output_final_state=g.output_final_state, - ssm_state_dtype=g.ssm_state_dtype, + + # head_seqlens=None : preserve the original "equal split via + # _build_context_lens" behavior. Otherwise materialise one + # PrefillArgs per (tp, full_len, head) triple with an + # explicit 3-segment cu_seqlens layout + # [head, mid_seqlen, full_len - head - mid_seqlen]. + if g.head_seqlens is None: + out.append( + PrefillArgs( + K=g.K, V=g.V, Hk=g.Hk, Hv=g.Hv, + tp=tp, + full_prompt_len=full_len, + model_name=g.model_name, + BT=g.BT, + max_num_batched_tokens=mnbt, + dtype=g.dtype, + is_varlen=g.is_varlen, + output_final_state=g.output_final_state, + ssm_state_dtype=g.ssm_state_dtype, + ) ) - ) + else: + for head in g.head_seqlens: + if g.num_segments == 2: + tail = full_len - head + if tail <= 0: + raise ValueError( + f"head_seqlens (num_segments=2) produced " + f"non-positive tail ({tail}) for " + f"group={g.model_name!r} " + f"full_prompt_len={full_len} head={head}." + ) + context_lens = [head, tail] + tag = f"head{head}_tail{tail}" + elif g.num_segments == 3: + tail = full_len - head - g.mid_seqlen + if tail <= 0: + raise ValueError( + f"head_seqlens (num_segments=3) produced " + f"non-positive tail ({tail}) for " + f"group={g.model_name!r} " + f"full_prompt_len={full_len} head={head} " + f"mid_seqlen={g.mid_seqlen}. Drop this " + f"(full_len, head) combo or raise " + f"full_prompt_len." + ) + context_lens = [head, g.mid_seqlen, tail] + tag = f"head{head}_mid{g.mid_seqlen}" + else: + raise ValueError( + f"num_segments={g.num_segments} unsupported; " + f"only 2 or 3 are implemented." + ) + out.append( + PrefillArgs( + K=g.K, V=g.V, Hk=g.Hk, Hv=g.Hv, + tp=tp, + full_prompt_len=full_len, + model_name=g.model_name, + BT=g.BT, + max_num_batched_tokens=mnbt, + dtype=g.dtype, + is_varlen=g.is_varlen, + output_final_state=g.output_final_state, + ssm_state_dtype=g.ssm_state_dtype, + context_lens=context_lens, + trace_tag=tag, + ) + ) return out @@ -254,7 +332,6 @@ def expand_groups(groups): model_name="varlen-16k-aws", Hv=32, tps=[1], - # full_prompt_lens=[1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000], full_prompt_lens=[1000, 5000, 10000], max_num_batched_tokens=16384, ), @@ -262,39 +339,40 @@ def expand_groups(groups): model_name="varlen-32k-aws", Hv=32, tps=[1], - # full_prompt_lens=[1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000], full_prompt_lens=[1000, 5000, 10000], max_num_batched_tokens=32768, ), + PrefillGroup( + model_name="flydsl-k5-n1", + Hv=32, + tps=[1], + full_prompt_lens=[5000, 10000], + max_num_batched_tokens="full_prompt_len", + ), + PrefillGroup( + model_name="flydsl-k5-n3-mid10k", + Hv=32, + tps=[1], + full_prompt_lens=[16384], + max_num_batched_tokens=16384, + head_seqlens=[5, 10, 65, 704, 936, 1820, 4467, 5508], + mid_seqlen=10000, + ), + PrefillGroup( + model_name="flydsl-k5-n2-16k", + Hv=32, + tps=[1], + full_prompt_lens=[16384], + max_num_batched_tokens=16384, + head_seqlens=[4000, 6396, 8192, 9912, 10000], + num_segments=2, + ), ] PREFILL_PARAMS = expand_groups(_PREFILL_GROUPS) -# Perf-test parametrization is identical to the correctness one; trace- -# derived shapes have been removed from this file. -PERF_PARAMS = list(PREFILL_PARAMS) -PERF_TEST_IDS = [repr(p) for p in PERF_PARAMS] - - -# Mirror every base shape with a bf16-SSM-state variant. The bf16 vs f32 -# kernel paths only differ in two ``if const_expr`` branches: -# - h0 load (gated by USE_INITIAL_STATE) -# - ht store (gated by STORE_FINAL_STATE) -# The bf16 mirror keeps ``output_final_state`` from the base shape, so: -# - ``_nofs`` shapes (use_h0=True, store_fs=False) cover the h0 load path -# - default shapes (use_h0=True, store_fs=True) cover both paths -# Only ``(use_h0=False, store_fs=False)`` would generate IR identical to -# the f32 path; none of the current PREFILL_PARAMS hits that combo, so we -# do not filter here. If you add such a case later, gate the mirror with -# ``if _base.output_final_state or _make_inputs(...) provides h0``. -# NOTE: bf16 SSM-state mirrors disabled for focused perf profiling. -# PREFILL_PARAMS.extend( -# [ -# _dataclass_replace(_base, ssm_state_dtype=torch.bfloat16) -# for _base in list(PREFILL_PARAMS) -# ] -# ) +PREFILL_TEST_IDS = [repr(p) for p in PREFILL_PARAMS] # -- bf16 SSM-state params (paired with TestStateDtypeBF16 below) ------ @@ -560,8 +638,6 @@ def _bench_fn(fn, *args, **kwargs): # -- Correctness tests --------------------------------------------------- -PREFILL_TEST_IDS = [repr(p) for p in PREFILL_PARAMS] - def _assert_k5_outputs_match_ref( h_out, @@ -971,68 +1047,16 @@ def triton_origin_opt_launch(): us_vllm / us_fly if (us_fly > 0 and us_vllm == us_vllm) else float("nan") ) - # bench333 cases carry trace-derived structural features (head/mid/ - # tail seqlen, log_count); compute these once so the summary table - # can display them in a bench333-specific sub-table. Non-bench333 - # rows leave these fields at their defaults (None / 0). - head_seqlen = mid_seqlen = tail_seqlen = 0 - log_count = 0 - if args.context_lens is not None and args.model_name == "prefill-bench333": - lens = list(args.context_lens) - if len(lens) == 1: - head_seqlen, tail_seqlen = int(lens[0]), 0 - elif len(lens) == 2: - head_seqlen, tail_seqlen = int(lens[0]), int(lens[1]) - elif len(lens) >= 3: - mids = lens[1:-1] - head_seqlen = int(lens[0]) - tail_seqlen = int(lens[-1]) - mid_seqlen = int(mids[0]) if all(m == mids[0] for m in mids) else 0 - # trace_tag is set to "cnt{log_count}" by _build_bench407_params. - if args.trace_tag.startswith("cnt"): - try: - log_count = int(args.trace_tag[3:]) - except ValueError: - log_count = 0 - - # For bench333 trace shapes the model name is the same string for - # all 333 rows ("prefill-bench333") which is useless in the per-row - # summary table. Use the trailing "T{T}_n{N}_cnt{log_count}" suffix - # of the pytest id instead, so each row's Model cell uniquely - # identifies the case. - if args.model_name == "prefill-bench333" and args.context_lens is not None: - n_seqs = len(args.context_lens) - T_total = sum(args.context_lens) - model_label = f"T{T_total}_n{n_seqs}_{args.trace_tag}" - else: - model_label = args.model_name or "-" - _perf_results.append( { - "Model": model_label, + "Model": args.model_name or "-", "TP": args.tp, - "K": args.K, - "V": args.V, "Hg": args.Hg, "H": args.H, "SeqLen": args.full_prompt_len, "T": total_tokens, "varlen": args.is_varlen, "final_st": args.output_final_state, - "state": "bf16" if args.ssm_state_dtype == torch.bfloat16 else "fp32", - # bench333-only fields (0 for main-table rows) - "T_prefill": total_tokens if args.model_name == "prefill-bench333" else 0, - "num_seqs_prefill": ( - len(args.context_lens) - if args.context_lens is not None - and args.model_name == "prefill-bench333" - else 0 - ), - "head_seqlen": head_seqlen, - "mid_seqlen": mid_seqlen, - "tail_seqlen": tail_seqlen, - "log_count": log_count, - # Perf columns (same for all rows) "FlyDSL_vk(us)": us_fly, "Triton_vk(us)": us_triton_vk, "Triton_origin_opt(us)": us_triton_origin_opt, @@ -1047,7 +1071,7 @@ def triton_origin_opt_launch(): class TestPerformance: """Kernel-only performance comparison: FlyDSL vs Triton opt_vk vs Triton opt3_kv.""" - @pytest.mark.parametrize("args", PERF_PARAMS, ids=PERF_TEST_IDS) + @pytest.mark.parametrize("args", PREFILL_PARAMS, ids=PREFILL_TEST_IDS) def test_perf_comparison(self, args: PrefillArgs): _run_perf_comparison(args) @@ -1056,24 +1080,7 @@ def _print_perf_table(): if not _perf_results: return - # Two column layouts: - # - main_cols : for hand-written PrefillGroup shapes (Qwen3.5-35B - # / 397B / varlen-fs / bench333-varlen) showing - # SeqLen + T + varlen + final_st. - # - bench_cols : for bench333 trace-derived shapes showing - # T_prefill + num_seqs_prefill + head/mid/tail - # + log_count instead. SeqLen / T are redundant - # here because T_prefill carries the same info. - perf_tail_cols = [ - ("FlyDSL", "FlyDSL_vk(us)", 8), - ("Tri_vk", "Triton_vk(us)", 8), - ("Tri_orig_opt", "Triton_origin_opt(us)", 12), - ("vLLM", "vLLM_vk(us)", 8), - ("fly/vk", "flydsl_vs_vk", 7), - ("fly/o_opt", "flydsl_vs_origin_opt", 9), - ("fly/vllm", "flydsl_vs_vllm", 8), - ] - main_cols = [ + cols = [ ("Model", "Model", 16), ("TP", "TP", 3), ("Hg", "Hg", 3), @@ -1082,113 +1089,41 @@ def _print_perf_table(): ("T", "T", 7), ("var", "varlen", 3), ("fs", "final_st", 3), - ] + perf_tail_cols - bench_cols = [ - ("Model", "Model", 22), - ("TP", "TP", 3), - ("Hg", "Hg", 3), - ("H", "H", 3), - ("T_prefill", "T_prefill", 9), - ("n_pref", "num_seqs_prefill", 6), - ("head", "head_seqlen", 6), - ("mid", "mid_seqlen", 6), - ("tail", "tail_seqlen", 6), - ("log_cnt", "log_count", 7), - ] + perf_tail_cols - - def _build_header_sep(cols): - header = " | ".join(display.rjust(width) for display, _, width in cols) - sep = "-+-".join("-" * width for _, _, width in cols) - return header, sep - - def _fmt_row(row, cols): - cells = [] - for display, key, width in cols: - val = row[key] - if isinstance(val, bool): - cells.append(("Y" if val else "N").rjust(width)) - elif isinstance(val, float): - if val != val: # NaN (e.g. vLLM column when vllm not installed) - cells.append("-".rjust(width)) - elif "_vs_" in key: - cells.append(f"{val:.2f}x".rjust(width)) - else: - cells.append(f"{val:.1f}".rjust(width)) - else: - cells.append(str(val).rjust(width)) - return " | ".join(cells) - - main_header, main_sep = _build_header_sep(main_cols) - bench_header, bench_sep = _build_header_sep(bench_cols) - border = "=" * max(len(main_header), len(bench_header)) - - # Bucket rows by SSM-state dtype and by main-vs-bench333. Keep each - # bucket's order consistent with ``_perf_results`` insertion so rows - # line up with the parametrize id order. bench333 trace rows are - # tagged by ``log_count > 0`` (main groups all leave log_count at 0). - def _split(rows): - main = [r for r in rows if r.get("log_count", 0) == 0] - bench = [r for r in rows if r.get("log_count", 0) > 0] - return main, bench - - rows_fp32_main, rows_fp32_bench = _split( - [r for r in _perf_results if r["state"] == "fp32"] - ) - rows_bf16_main, rows_bf16_bench = _split( - [r for r in _perf_results if r["state"] == "bf16"] - ) + ("FlyDSL", "FlyDSL_vk(us)", 8), + ("Tri_vk", "Triton_vk(us)", 8), + ("Tri_orig_opt", "Triton_origin_opt(us)", 12), + ("vLLM", "vLLM_vk(us)", 8), + ("fly/vk", "flydsl_vs_vk", 7), + ("fly/o_opt", "flydsl_vs_origin_opt", 9), + ("fly/vllm", "flydsl_vs_vllm", 8), + ] - lines = ["", border] - lines.append( - "K5 Prefill Performance Summary " - "(K5 device kernel time only, via torch.profiler)" - ) - lines.append( - " Triton K5 references always use fp32 SSM state; only FlyDSL's " - "SSM-state dtype changes between the sub-tables below." - ) - lines.append(border) - - def _emit_subtable(title, rows, cols, header, sep): - if not rows: - return - lines.append("") - lines.append(title) - lines.append(sep) - lines.append(header) - lines.append(sep) - for row in rows: - lines.append(_fmt_row(row, cols)) - lines.append(sep) - - _emit_subtable( - "[FlyDSL SSM state = fp32] -- main groups", - rows_fp32_main, - main_cols, - main_header, - main_sep, - ) - _emit_subtable( - "[FlyDSL SSM state = fp32] -- bench333 trace shapes", - rows_fp32_bench, - bench_cols, - bench_header, - bench_sep, - ) - _emit_subtable( - "[FlyDSL SSM state = bf16] -- main groups", - rows_bf16_main, - main_cols, - main_header, - main_sep, - ) - _emit_subtable( - "[FlyDSL SSM state = bf16] -- bench333 trace shapes", - rows_bf16_bench, - bench_cols, - bench_header, - bench_sep, - ) + def _fmt_cell(val, key, width): + if isinstance(val, bool): + return ("Y" if val else "N").rjust(width) + if isinstance(val, float): + if val != val: # NaN (vLLM column when vllm not installed) + return "-".rjust(width) + return (f"{val:.2f}x" if "_vs_" in key else f"{val:.1f}").rjust(width) + return str(val).rjust(width) + + header = " | ".join(display.rjust(w) for display, _, w in cols) + sep = "-+-".join("-" * w for _, _, w in cols) + border = "=" * len(header) + + lines = [ + "", + border, + "K5 Prefill Performance Summary (K5 device kernel time only, via torch.profiler)", + border, + "", + sep, + header, + sep, + ] + for row in _perf_results: + lines.append(" | ".join(_fmt_cell(row[k], k, w) for _, k, w in cols)) + lines.append(sep) lines.append("") print("\n".join(lines))