Support an efficient NKI implementation pattern for ragged paged attention #1346

Description

@hugomano

Describe the feature

We (ZML) are trying to write a production paged-attention kernel in NKI for Neuron. The kernel needs to be correct for arbitrary physical page order, and fast when the compiled batch or sequence capacity is larger than the live work.

I am filing this in the main AWS Neuron SDK repository rather than a narrower component repository because the blocker appears to span several layers.

We looked at vLLM Neuron first, especially release 0.5.1, because vLLM is the natural place to look for PagedAttention. That release is compatible with vLLM 0.16.0 and Neuron SDK 2.30.0, but its README says that the NxD Inference path uses a contiguous K/V cache layout rather than PagedAttention, and that chunked prefill is not supported. As far as we can tell, there is no current vLLM Neuron paged-attention kernel design that we can simply copy.

What we need is a supported NKI/compiler pattern for this architecture.

Contract

Inputs:

q:                [num_tokens, num_kv_heads, heads_per_kv, head_dim]
k_cache:          [num_pages, page_size, num_kv_heads, head_dim]
v_cache:          [num_pages, page_size, num_kv_heads, head_dim]
block_table:      [batch_capacity, max_pages_per_sequence]
seq_lens:         [batch_capacity]
query_start_len:  [batch_capacity + 1]
out:              same logical shape as q

Semantics:

  • block_table[seq, logical_page] maps logical pages to physical pages.
  • Physical pages may be permuted: the kernel cannot assume contiguous physical pages.
  • Inactive pages must be safe with oob_mode=skip or equivalent semantics.
  • seq_lens is the live sequence length, not necessarily the compiled maximum.
  • query_start_len gives compact/ragged query row intervals.
  • GQA must work, for example 32 Q heads, 8 KV heads, head dim 128.
  • Page sizes 16, 32 and 64 are important for serving efficiency.
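As an illustration of the block_table semantics above, here is a minimal host-side sketch in plain Python (not NKI). The helper name physical_slot and the sample table are hypothetical; the -1 value for an inactive page follows the sentinel convention used later in this issue.

```python
def physical_slot(block_table, seq_idx, token_pos, page_size):
    """Map a logical token position to (physical_page, offset_in_page).

    Returns None for an inactive page (-1 sentinel), which models
    skip semantics: nothing is loaded for that page.
    """
    logical_page, offset = divmod(token_pos, page_size)
    page = block_table[seq_idx][logical_page]
    if page < 0:
        return None
    return page, offset


# Hypothetical table: sequence 0 owns physical pages 7 and 2 (permuted order),
# sequence 1 is an inactive batch slot.
block_table = [[7, 2, -1], [-1, -1, -1]]

slot_live = physical_slot(block_table, 0, 17, page_size=16)      # (2, 1)
slot_inactive = physical_slot(block_table, 0, 40, page_size=16)  # None
```

Token 17 with page size 16 lands in logical page 1, which the table maps to physical page 2, so physical order never has to match logical order.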

The important performance rule is: stream only live K/V pages that can affect the current query tile. Masks are still needed for tails and causal edges, but masks should not be the primary way to skip inactive rows or inactive K/V tiles.

For decode, the useful work is:

active rows * kv heads * heads_per_kv * live K/V tiles

not:

compiled batch capacity * kv heads * heads_per_kv * compiled max K/V tiles
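To make the gap between the two counts concrete, the sketch below evaluates both products in plain Python. All numbers (4 batch slots, a 4096-token compiled maximum, 128-token K/V tiles) are made-up assumptions for illustration, not measurements.

```python
import math


def live_decode_work(live_seq_lens, kv_heads, heads_per_kv, k_seg):
    """Tiles touched when each row streams only its live K/V tiles."""
    return sum(
        kv_heads * heads_per_kv * math.ceil(n / k_seg) for n in live_seq_lens
    )


def compiled_decode_work(batch_capacity, kv_heads, heads_per_kv, max_kv_tokens, k_seg):
    """Tiles touched when every slot walks the compiled maximum."""
    return batch_capacity * kv_heads * heads_per_kv * math.ceil(max_kv_tokens / k_seg)


# Hypothetical batch: two active rows, two inactive slots (length 0).
live = live_decode_work([100, 900, 0, 0], kv_heads=8, heads_per_kv=4, k_seg=128)
compiled = compiled_decode_work(4, kv_heads=8, heads_per_kv=4, max_kv_tokens=4096, k_seg=128)
print(live, compiled)  # 288 4096
```

In this made-up batch the masked-only approach touches roughly 14 times more tiles than the live work requires.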

For prefill/mixed attention, the useful K/V prefix for a Q tile follows the ordinary causal attention rule. It is not a new public paged-attention semantic; it is just the implementation bound for "how many key tokens are visible to this Q tile".

For a sequence:

q_start = query_start_len[seq_idx]
q_end = query_start_len[seq_idx + 1]
q_len = q_end - q_start
context_len = seq_lens[seq_idx] - q_len

For a Q tile:

tile_query_start = current_q_start - q_start
tile_query_rows = min(Q_TILE, q_len - tile_query_start)
tile_query_end = tile_query_start + tile_query_rows
visible_k_tokens_for_tile = min(seq_lens[seq_idx], context_len + tile_query_end)

The K/V streaming loop should cover only the token range [0, visible_k_tokens_for_tile).
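The per-tile bound can be checked on the host with a direct transcription of the formulas above into plain Python. The function signature and the sample values (one sequence of 1000 tokens, of which 300 are new queries, with Q_TILE = 128) are assumptions for illustration.

```python
def visible_k_tokens_for_tile(query_start_len, seq_lens, seq_idx, current_q_start, q_tile):
    """Number of key tokens visible to the Q tile starting at current_q_start."""
    q_start = query_start_len[seq_idx]
    q_end = query_start_len[seq_idx + 1]
    q_len = q_end - q_start
    context_len = seq_lens[seq_idx] - q_len

    tile_query_start = current_q_start - q_start
    tile_query_rows = min(q_tile, q_len - tile_query_start)
    tile_query_end = tile_query_start + tile_query_rows
    return min(seq_lens[seq_idx], context_len + tile_query_end)


# Hypothetical sequence: 700 tokens of cached context plus 300 new query rows.
query_start_len = [0, 300]
seq_lens = [1000]
bounds = [
    visible_k_tokens_for_tile(query_start_len, seq_lens, 0, start, q_tile=128)
    for start in range(0, 300, 128)
]
print(bounds)  # [828, 956, 1000]
```

Only the last Q tile needs the full 1000-token prefix; earlier tiles can stop streaming sooner.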

Target architecture

The target architecture is a split decode and prefill/mixed implementation over the same paged cache contract. Both paths keep the online-softmax recurrence state in SBUF. Both paths load K/V through block_table, preserve arbitrary physical page order, and mark inactive pages with a -1 sentinel that is handled with skip semantics.

Decode should be the simpler path. It has one query row per active sequence, so the runtime-controlled loop only needs to cover live K/V segments:

for kv_head in kv_heads:
    for seq_idx in compiled_decode_rows:
        active_seq_len = where(seq_idx_is_active, seq_lens[seq_idx], 0)
        segment_count = ceil(active_seq_len / K_SEG)
        for _ in runtime_range(segment_count):
            pages = block_table[seq_idx, current_pages]
            pages = where(page_is_live, pages, -1)
            k = load_k_pages(k_cache, pages, oob_mode=skip)
            qk = q @ k
            update_online_softmax_state(qk)
            v = load_v_pages(v_cache, pages, oob_mode=skip)
            update_output(v)

This is not just masking inactive pages after the fact. Inactive rows should produce zero K/V segments, and active rows should stream only their live K/V segments.
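For reference, the loop-carried state in the decode path can be modeled in plain Python for a single query row and a single head. This is only a numerical sketch of the online-softmax recurrence, not NKI code: the function names are hypothetical, no 1/sqrt(head_dim) scaling is applied (matching the pseudocode above), and returning zeros for a row with no live tokens is an assumption.

```python
import math


def online_softmax_attention(q, keys, values, k_seg):
    """Stream K/V in segments of k_seg tokens, carrying (m, l, acc)."""
    m = -math.inf            # running maximum score
    l = 0.0                  # running sum of exp(score - m)
    acc = [0.0] * len(q)     # running unnormalized output
    for start in range(0, len(keys), k_seg):
        seg_k = keys[start:start + k_seg]
        seg_v = values[start:start + k_seg]
        scores = [sum(a * b for a, b in zip(q, k)) for k in seg_k]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # 0.0 on the first segment
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, seg_v):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new
    if l == 0.0:
        return acc  # zero segments streamed: assumed all-zero output
    return [a / l for a in acc]


def reference_attention(q, keys, values):
    """Unsegmented softmax(q . K) V for comparison."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(len(q))]


q = [0.5, -1.0]
keys = [[0.1 * i, 0.3 - 0.2 * i] for i in range(7)]
values = [[float(i), float(-i)] for i in range(7)]
streamed = online_softmax_attention(q, keys, values, k_seg=3)
direct = reference_attention(q, keys, values)
```

The number of iterations of the outer loop is exactly the live segment count, so an inactive row performs no K/V work at all.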

Prefill/mixed attention needs a runtime Q-tile loop for compile stability on long prompts, and inside each Q tile it needs a runtime K/V streaming loop bounded by the visible K/V prefix:

for kv_head in kv_heads:
    for seq_idx in batch_capacity:
        current_q_start = query_start_len[seq_idx]
        q_end = query_start_len[seq_idx + 1]

        while current_q_start < q_end:
            q_tile = load_compact_q_tile(current_q_start)
            visible_k_tokens_for_tile = compute_visible_k_prefix(...)
            segment_count = ceil(visible_k_tokens_for_tile / K_SEG)

            for _ in runtime_range(segment_count):
                pages = block_table[seq_idx, current_pages]
                pages = where(page_is_live, pages, -1)
                k = load_k_pages(k_cache, pages, oob_mode=skip)
                qk = q_tile @ k
                qk = apply_causal_and_tail_masks(qk)
                update_online_softmax_state(qk)
                v = load_v_pages(v_cache, pages, oob_mode=skip)
                update_output(v)

            current_q_start += Q_TILE

This is the control-flow shape we need for ragged paged attention: the number of Q tiles and the number of K/V segments are both runtime quantities. The kernel should be able to advance Q tiles at runtime, and for each Q tile stream only the K/V segments that are live and visible to that tile.
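The resulting trip counts can be enumerated on the host. The plain-Python sketch below lists, for each Q tile, how many K/V segments the inner loop would run; the function name and the sample batch (Q_TILE = 128, K_SEG = 128, one inactive slot) are assumptions for illustration.

```python
import math


def prefill_schedule(query_start_len, seq_lens, q_tile, k_seg):
    """Return (seq_idx, tile_q_start, segment_count) for every Q tile."""
    schedule = []
    for seq_idx in range(len(seq_lens)):
        q_start = query_start_len[seq_idx]
        q_end = query_start_len[seq_idx + 1]
        q_len = q_end - q_start
        context_len = seq_lens[seq_idx] - q_len

        current_q_start = q_start
        while current_q_start < q_end:  # runtime Q-tile loop
            tile_query_end = min(current_q_start - q_start + q_tile, q_len)
            visible = min(seq_lens[seq_idx], context_len + tile_query_end)
            schedule.append((seq_idx, current_q_start, math.ceil(visible / k_seg)))
            current_q_start += q_tile
    return schedule


# Hypothetical batch: a 300-row chunk over 700 cached tokens,
# an inactive slot, and a fresh 10-token prompt.
schedule = prefill_schedule(
    query_start_len=[0, 300, 300, 310],
    seq_lens=[1000, 0, 10],
    q_tile=128,
    k_seg=128,
)
print(schedule)  # [(0, 0, 7), (0, 128, 8), (0, 256, 8), (2, 300, 1)]
```

Both the number of entries per sequence and the segment count in each entry depend on runtime data, which is why neither loop can be unrolled to a compile-time bound without wasted work.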

Current blocker

The reproducer below is intentionally not a full attention kernel. It isolates the compiler issue we hit while trying to express the target architecture:

dynamic outer loop over Q tiles
    dynamic inner loop over visible K/V segments

The script first invokes the NKI kernel normally to generate the compiler artifacts. In the installed NKI path, that invocation succeeds because one NKI standalone compile path passes a hidden internal backend option that enables nested dynamic loops. That option does not appear in the public neuronx-cc compile --help output and is not a documented NKI/kernel contract. In our framework we compile to BIR (https://github.com/zml/zml/pull/528/changes#diff-689ea608674648377c0e9cf5943ebe6591e2c2f01f5b6f846a1b478fe2b7676a) and wrap the result in an AwsNeuronCustomNativeKernel through StableHLO for neuronxla 3.0. The script then replays the generated penguin.py through neuronx-cc without the hidden option. The replay failure is the compiler behavior this issue is about: a dynamic loop nested inside another dynamic loop is rejected.

repro_nested_dynamic_loops.py:

import os
import shutil
import subprocess
from pathlib import Path

os.environ.setdefault("NEURON_CC_FLAGS", "--target inf2 --lnc 1 --verbose=info")
os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", "inf2")
os.environ.setdefault("NKI_ARTIFACTS_DIR", "/tmp/nki_nested_dynamic_loop_repro")
os.environ.setdefault("NKI_VERBOSE_COMPILE", "1")

import numpy as np
import nki
import nki.isa as nisa
import nki.language as nl


@nki.jit
def nested_dynamic_loop_repro(loop_counts):
    counts = nl.ndarray((2, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)
    one = nl.full((1, 1), 1, dtype=loop_counts.dtype, buffer=nl.sbuf)
    zero = nl.full((1, 1), 0, dtype=loop_counts.dtype, buffer=nl.sbuf)
    zero_f32 = nl.full((1, 1), 0.0, dtype=nl.float32, buffer=nl.sbuf)
    outer_remaining = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)
    outer_remaining_f32 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
    total = nl.full((1, 1), 0, dtype=loop_counts.dtype, buffer=nl.sbuf)
    next_total = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)

    nisa.dma_copy(dst=counts, src=loop_counts[0:2, 0:1])
    inner_loop = nisa.register_alloc(None)
    nisa.register_load(dst=inner_loop, src=counts[1:2, 0:1])
    nisa.tensor_copy(dst=outer_remaining, src=counts[0:1, 0:1])
    nisa.tensor_copy(dst=outer_remaining_f32, src=outer_remaining)
    outer_alive = nl.where(
        nl.greater(outer_remaining_f32, zero_f32),
        one,
        zero,
        dtype=loop_counts.dtype,
    )
    outer_loop = nisa.register_alloc(None)
    nisa.register_load(dst=outer_loop, src=outer_alive)

    while outer_loop:
        for _ in nl.dynamic_range(inner_loop):
            nisa.tensor_tensor(dst=next_total, data1=total, data2=one, op=nl.add)
            nisa.tensor_copy(dst=total, src=next_total)
        nisa.tensor_tensor(
            dst=outer_remaining,
            data1=outer_remaining,
            data2=one,
            op=nl.subtract,
        )
        nisa.tensor_copy(dst=outer_remaining_f32, src=outer_remaining)
        outer_alive = nl.where(
            nl.greater(outer_remaining_f32, zero_f32),
            one,
            zero,
            dtype=loop_counts.dtype,
        )
        nisa.register_load(dst=outer_loop, src=outer_alive)

    out = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.shared_hbm)
    nisa.dma_copy(dst=out[0:1, 0:1], src=total)
    return out


if __name__ == "__main__":
    artifacts = Path(os.environ["NKI_ARTIFACTS_DIR"])
    replay = Path("/tmp/nki_nested_dynamic_loop_replay")
    shutil.rmtree(artifacts, ignore_errors=True)
    shutil.rmtree(replay, ignore_errors=True)

    print(nested_dynamic_loop_repro(np.array([[2], [3]], dtype=np.int32)))

    replay.mkdir()
    shutil.copy(artifacts / "penguin.py", replay / "penguin.py")
    for path in artifacts.glob("__main__*.json"):
        shutil.copy(path, replay / path.name)

    subprocess.run(
        [
            "neuronx-cc",
            "compile",
            "--framework",
            "XLA",
            "penguin.py",
            "--pipeline",
            "compile",
            "SaveTemps",
            "--target",
            "inf2",
            "--lnc",
            "1",
            "--output",
            str(replay / "kernel.neff"),
        ],
        cwd=replay,
        check=True,
    )

The replay fails with:

[INTERNAL_ERROR] [NCC_ICFG018] Loop structures require exactly one internal back-edge as we currently do not support CONTINUE statements or nested loops.

and:

with internal back-edges:
Block1_LoopBody_1_Resume_2 -> Block1_LoopBody_1,
Block1_LoopBody_1_LoopBody_2 -> Block1_LoopBody_1_LoopBody_2.
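For anyone validating a build where the nested loops do compile, the arithmetic the reproducer is meant to perform can be modeled in plain Python. This is a reading of the kernel source above, not output captured from hardware; the function name is hypothetical.

```python
def expected_total(outer_count, inner_count):
    """Host-side model of nested_dynamic_loop_repro's counter logic."""
    total = 0
    outer_remaining = outer_count
    while outer_remaining > 0:        # mirrors `while outer_loop`
        for _ in range(inner_count):  # mirrors `nl.dynamic_range(inner_loop)`
            total += 1
        outer_remaining -= 1
    return total


# The script passes loop_counts = [[2], [3]]: 2 outer trips, 3 inner trips.
result = expected_total(2, 3)
print(result)  # 6
```

With the script's input, a correct compile should therefore produce a 1x1 output holding 6, assuming the kernel behaves as its source reads.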

Use Case

The use case is high-throughput LLM serving on Inf2/Trn2/Trn3 with a paged KV cache. Requests are ragged: some batch slots are inactive, decode rows have different live sequence lengths, and prefill/mixed batches may contain prompts with different query lengths and context lengths.

Proposed Solution

Concretely, guidance or compiler support is needed for:

  • nested dynamic loops for Q tiles and visible K/V segments;
  • bounded runtime K/V streaming loops with loop-carried online-softmax state;
  • active-row decode without work for inactive batch slots;
  • prefill/mixed attention that streams only the K/V prefix visible to each Q tile;
  • support for Inf2, Trn2 and Trn3 instances.

Other Information

Compiler environment for the reproducer:

target: inf2.8xlarge
NeuronX compiler: 2.25.3371.0+f524f7f8
NKI: 0.4.0+25940409122.gd30719f9

Acknowledgements

  • I may be able to implement this feature request
