Describe the feature
We (ZML) are trying to write a production paged-attention kernel in NKI for Neuron. The kernel needs to be correct for arbitrary physical page order, and fast when the compiled batch or sequence capacity is larger than the live work.
I am filing this in the main AWS Neuron SDK repository rather than a narrower component repository because the blocker appears to span several layers.
We looked at vLLM Neuron first, especially release 0.5.1, because vLLM is the natural place to look for PagedAttention. That release is compatible with vLLM 0.16.0 and Neuron SDK 2.30.0, but its README says that the NxD Inference path uses a contiguous K/V cache layout rather than PagedAttention, and that chunked prefill is not supported. As far as we can tell, there is no current vLLM Neuron paged-attention kernel design that we can simply copy.
What we need is a supported NKI/compiler pattern for this architecture.
Contract
Inputs:
q: [num_tokens, num_kv_heads, heads_per_kv, head_dim]
k_cache: [num_pages, page_size, num_kv_heads, head_dim]
v_cache: [num_pages, page_size, num_kv_heads, head_dim]
block_table: [batch_capacity, max_pages_per_sequence]
seq_lens: [batch_capacity]
query_start_len: [batch_capacity + 1]
out: same logical shape as q
Semantics:
- block_table[seq, logical_page] maps logical pages to physical pages.
- Physical pages may be permuted: the kernel cannot assume contiguous physical pages.
- Inactive pages must be safe with oob_mode=skip or equivalent semantics.
- seq_lens is the live sequence length, not necessarily the compiled maximum.
- query_start_len gives compact/ragged query row intervals.
- GQA must work, for example 32 Q heads, 8 KV heads, head dim 128.
- Page sizes 16, 32 and 64 are important for serving efficiency.
The important performance rule is: stream only live K/V pages that can affect the current query tile. Masks are still needed for tails and causal edges, but masks should not be the primary way to skip inactive rows or inactive K/V tiles.
For decode, the useful work is:
active rows * kv heads * heads_per_kv * live K/V tiles
not:
compiled batch capacity * kv heads * heads_per_kv * compiled max K/V tiles
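To make the gap concrete, here is a minimal sketch of the two work counts in plain Python. The function names, the tile size `k_seg`, and the example batch are illustrative assumptions and not part of any NKI API; the per-row tile count is summed because live lengths differ per row.

```python
import math


def live_decode_work(seq_lens, active, kv_heads, heads_per_kv, k_seg):
    """Tile-level work when only live K/V tiles of active rows are streamed."""
    live_tiles = sum(
        math.ceil(n / k_seg) for n, is_active in zip(seq_lens, active) if is_active
    )
    return kv_heads * heads_per_kv * live_tiles


def compiled_decode_work(batch_capacity, max_seq_len, kv_heads, heads_per_kv, k_seg):
    """Tile-level work when every compiled slot streams the compiled maximum."""
    return batch_capacity * kv_heads * heads_per_kv * math.ceil(max_seq_len / k_seg)


# Hypothetical batch: 8 compiled slots, 2 of them active.
seq_lens = [300, 0, 1200, 0, 0, 0, 0, 0]
active = [n > 0 for n in seq_lens]

live = live_decode_work(seq_lens, active, kv_heads=8, heads_per_kv=4, k_seg=128)
compiled = compiled_decode_work(8, 4096, kv_heads=8, heads_per_kv=4, k_seg=128)
print(live, compiled)  # 416 8192
```

In this made-up batch the compiled-capacity loop does roughly 20 times the tile work of the live loop, which is the cost that masking alone cannot remove.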
For prefill/mixed attention, the useful K/V prefix for a Q tile follows the ordinary causal attention rule. It is not a new public paged-attention semantic; it is just the implementation bound for "how many key tokens are visible to this Q tile".
For a sequence:
q_start = query_start_len[seq_idx]
q_end = query_start_len[seq_idx + 1]
q_len = q_end - q_start
context_len = seq_lens[seq_idx] - q_len
For a Q tile:
tile_query_start = current_q_start - q_start
tile_query_rows = min(Q_TILE, q_len - tile_query_start)
tile_query_end = tile_query_start + tile_query_rows
visible_k_tokens_for_tile = min(seq_lens[seq_idx], context_len + tile_query_end)
The K/V streaming loop should cover only [0, visible_k_tokens_for_tile)
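The bound above can be written as a small pure-Python helper. This is only a sketch of the arithmetic, assuming plain integer sequences for `query_start_len` and `seq_lens`; the function name and the `q_tile` parameter are hypothetical.

```python
def visible_k_tokens_for_tile(query_start_len, seq_lens, seq_idx, current_q_start, q_tile):
    """Number of key tokens visible to the Q tile starting at current_q_start."""
    q_start = query_start_len[seq_idx]
    q_end = query_start_len[seq_idx + 1]
    q_len = q_end - q_start
    context_len = seq_lens[seq_idx] - q_len

    tile_query_start = current_q_start - q_start
    tile_query_rows = min(q_tile, q_len - tile_query_start)
    tile_query_end = tile_query_start + tile_query_rows
    return min(seq_lens[seq_idx], context_len + tile_query_end)


# Hypothetical mixed batch: sequence 0 has 5 query rows on top of 7 cached
# tokens (prefill with context), sequence 1 is a decode row with 39 cached tokens.
query_start_len = [0, 5, 6]
seq_lens = [12, 40]

print(visible_k_tokens_for_tile(query_start_len, seq_lens, 0, 0, q_tile=4))  # 11
print(visible_k_tokens_for_tile(query_start_len, seq_lens, 0, 4, q_tile=4))  # 12
print(visible_k_tokens_for_tile(query_start_len, seq_lens, 1, 5, q_tile=4))  # 40
```

The first Q tile of sequence 0 never needs key token 11, so its K/V loop can stop one token early; only the last tile needs the full live length.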
Target architecture
The target architecture is a split decode and prefill/mixed implementation over the same paged cache contract. Both paths keep the online-softmax recurrence on sbuf. Both paths load K/V through block_table, preserve arbitrary physical page order, and use -1 inactive-page sentinels with skip semantics.
Decode should be the simpler path. It has one query row per active sequence, so the runtime-controlled loop only needs to cover live K/V segments:
for kv_head in kv_heads:
    for seq_idx in compiled_decode_rows:
        active_seq_len = where(seq_idx_is_active, seq_lens[seq_idx], 0)
        segment_count = ceil(active_seq_len / K_SEG)
        for _ in runtime_range(segment_count):
            pages = block_table[seq_idx, current_pages]
            pages = where(page_is_live, pages, -1)
            k = load_k_pages(k_cache, pages, oob_mode=skip)
            qk = q @ k
            update_online_softmax_state(qk)
            v = load_v_pages(v_cache, pages, oob_mode=skip)
            update_output(v)
This is not just masking inactive pages after the fact. Inactive rows should produce zero K/V segments, and active rows should stream only their live K/V segments.
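The page-selection part of the decode loop can be modelled on the host with NumPy. The sketch below is an assumption-laden model, not NKI code: `decode_segment_pages` and `load_pages_skip` are hypothetical helpers, `k_seg` is assumed to be a multiple of `page_size`, and skip semantics are modelled as "a -1 page leaves a zero-initialised tile untouched".

```python
import numpy as np


def decode_segment_pages(block_table, seq_lens, active, page_size, k_seg):
    """Yield (seq_idx, pages) for every live K/V segment of every active row."""
    pages_per_seg = k_seg // page_size
    for seq_idx in range(block_table.shape[0]):
        active_seq_len = int(seq_lens[seq_idx]) if active[seq_idx] else 0
        segment_count = -(-active_seq_len // k_seg)  # ceil; 0 for inactive rows
        for seg in range(segment_count):
            first = seg * pages_per_seg
            pages = block_table[seq_idx, first:first + pages_per_seg].copy()
            first_token = (first + np.arange(len(pages))) * page_size
            pages[first_token >= active_seq_len] = -1  # inactive-page sentinel
            yield seq_idx, pages


def load_pages_skip(cache, pages):
    """Model of a skip-style gather: -1 pages are not read at all."""
    tile = np.zeros((len(pages),) + cache.shape[1:], dtype=cache.dtype)
    live = pages >= 0
    tile[live] = cache[pages[live]]
    return tile


# Hypothetical state: permuted physical pages, stale ids past the live length.
block_table = np.array([[7, 2, 11, 4], [3, 3, 3, 3], [5, 6, 0, 0]])
seq_lens = np.array([40, 25, 9])
active = [True, False, True]

plan = [(s, p.tolist()) for s, p in
        decode_segment_pages(block_table, seq_lens, active, page_size=16, k_seg=32)]
print(plan)  # [(0, [7, 2]), (0, [11, -1]), (2, [5, -1])]
```

Row 1 is inactive and contributes no segments at all, even though its `seq_lens` entry is non-zero, and the stale page ids after the live length are replaced by the sentinel before any load is issued.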
Prefill/mixed attention needs a runtime Q-tile loop for compile stability on long prompts, and inside each Q tile it needs a runtime K/V streaming loop bounded by the visible K/V prefix:
for kv_head in kv_heads:
    for seq_idx in batch_capacity:
        current_q_start = query_start_len[seq_idx]
        q_end = query_start_len[seq_idx + 1]
        while current_q_start < q_end:
            q_tile = load_compact_q_tile(current_q_start)
            visible_k_tokens_for_tile = compute_visible_k_prefix(...)
            segment_count = ceil(visible_k_tokens_for_tile / K_SEG)
            for _ in runtime_range(segment_count):
                pages = block_table[seq_idx, current_pages]
                pages = where(page_is_live, pages, -1)
                k = load_k_pages(k_cache, pages, oob_mode=skip)
                qk = q_tile @ k
                qk = apply_causal_and_tail_masks(qk)
                update_online_softmax_state(qk)
                v = load_v_pages(v_cache, pages, oob_mode=skip)
                update_output(v)
            current_q_start += Q_TILE
This is the control-flow shape we need for ragged paged attention: the number of Q tiles and the number of K/V segments are both runtime quantities. The kernel should be able to advance Q tiles at runtime, and for each Q tile stream only the K/V segments that are live and visible to that tile.
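For checking a kernel against this control-flow shape, a host-side NumPy reference with the same two nested runtime loops may be useful. This is a sketch under stated assumptions, not the NKI implementation: the function names are hypothetical, `k_seg` is assumed to be a multiple of the page size, the softmax scale is assumed to be 1/sqrt(head_dim), and skip semantics are modelled as a zero-filled tile.

```python
import numpy as np


def load_pages(cache_h, pages):
    """Skip-style gather model: -1 pages leave the zeroed tile untouched."""
    tile = np.zeros((len(pages),) + cache_h.shape[1:])
    live = pages >= 0
    tile[live] = cache_h[pages[live]]
    return tile.reshape(-1, cache_h.shape[-1])


def paged_prefill_reference(q, k_cache, v_cache, block_table, seq_lens,
                            query_start_len, q_tile, k_seg):
    """q: [num_tokens, num_kv_heads, heads_per_kv, head_dim], compact rows."""
    page_size, head_dim = k_cache.shape[1], q.shape[-1]
    pages_per_seg = k_seg // page_size
    scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros(q.shape)
    for h in range(q.shape[1]):
        k_h, v_h = k_cache[:, :, h], v_cache[:, :, h]
        for s in range(len(seq_lens)):
            q_start, q_end = int(query_start_len[s]), int(query_start_len[s + 1])
            seq_len = int(seq_lens[s])
            context_len = seq_len - (q_end - q_start)
            cur = q_start
            while cur < q_end:  # runtime Q-tile loop
                rows = min(q_tile, q_end - cur)
                # Exclusive key bound per query row (ordinary causal rule).
                limit = context_len + (cur - q_start) + np.arange(rows) + 1
                visible = min(seq_len, int(limit[-1]))
                qt = q[cur:cur + rows, h]  # [rows, heads_per_kv, head_dim]
                m = np.full(qt.shape[:2], -np.inf)  # running row max
                l = np.zeros(qt.shape[:2])          # running softmax denominator
                acc = np.zeros(qt.shape)            # running weighted sum of V
                for seg in range(-(-visible // k_seg)):  # runtime K/V loop
                    first = seg * pages_per_seg
                    pages = block_table[s, first:first + pages_per_seg].copy()
                    page_tok0 = (first + np.arange(len(pages))) * page_size
                    pages[page_tok0 >= visible] = -1  # pages past the prefix
                    k, v = load_pages(k_h, pages), load_pages(v_h, pages)
                    tok = first * page_size + np.arange(k.shape[0])
                    scores = np.einsum("rgd,td->rgt", qt, k) * scale
                    hidden = tok[None, :] >= limit[:, None]  # causal and tail mask
                    scores = np.where(hidden[:, None, :], -np.inf, scores)
                    m_new = np.maximum(m, scores.max(axis=2))
                    alpha = np.exp(m - m_new)  # rescale previous state
                    p = np.exp(scores - m_new[..., None])
                    l = alpha * l + p.sum(axis=2)
                    acc = alpha[..., None] * acc + np.einsum("rgt,td->rgd", p, v)
                    m = m_new
                out[cur:cur + rows, h] = acc / l[..., None]
                cur += q_tile
    return out
```

The loop-carried state `(m, l, acc)` is the online-softmax recurrence that has to survive across iterations of the inner runtime loop; sequences with no query rows never enter the outer loop.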
Current blocker
The reproducer below is intentionally not a full attention kernel. It isolates the compiler issue we hit while trying to express the target architecture:
- a dynamic outer loop over Q tiles
- a dynamic inner loop over visible K/V segments
The script first invokes the NKI kernel normally to generate the compiler artifacts. In the installed NKI path, that normal invocation succeeds because one NKI standalone compile path uses a hidden internal backend option for nested dynamic loops. That option is not shown in the public neuronx-cc compile --help output and is not a documented NKI/kernel contract. In our framework we compile to BIR (https://github.com/zml/zml/pull/528/changes#diff-689ea608674648377c0e9cf5943ebe6591e2c2f01f5b6f846a1b478fe2b7676a) and wrap that into AwsNeuronCustomNativeKernel through StableHLO for neuronxla 3.0. The script then replays the generated penguin.py through neuronx-cc without that hidden internal option. The replay failure is the compiler behavior this issue is about, namely a dynamic loop nested inside another dynamic loop:
repro_nested_dynamic_loops.py:
import os
import shutil
import subprocess
from pathlib import Path

# Environment defaults are set before nki is imported.
os.environ.setdefault("NEURON_CC_FLAGS", "--target inf2 --lnc 1 --verbose=info")
os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", "inf2")
os.environ.setdefault("NKI_ARTIFACTS_DIR", "/tmp/nki_nested_dynamic_loop_repro")
os.environ.setdefault("NKI_VERBOSE_COMPILE", "1")

import numpy as np
import nki
import nki.isa as nisa
import nki.language as nl


@nki.jit
def nested_dynamic_loop_repro(loop_counts):
    # loop_counts[0] is the outer trip count, loop_counts[1] the inner trip count.
    counts = nl.ndarray((2, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)
    one = nl.full((1, 1), 1, dtype=loop_counts.dtype, buffer=nl.sbuf)
    zero = nl.full((1, 1), 0, dtype=loop_counts.dtype, buffer=nl.sbuf)
    zero_f32 = nl.full((1, 1), 0.0, dtype=nl.float32, buffer=nl.sbuf)
    outer_remaining = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)
    outer_remaining_f32 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
    total = nl.full((1, 1), 0, dtype=loop_counts.dtype, buffer=nl.sbuf)
    next_total = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=counts, src=loop_counts[0:2, 0:1])

    inner_loop = nisa.register_alloc(None)
    nisa.register_load(dst=inner_loop, src=counts[1:2, 0:1])

    nisa.tensor_copy(dst=outer_remaining, src=counts[0:1, 0:1])
    nisa.tensor_copy(dst=outer_remaining_f32, src=outer_remaining)
    outer_alive = nl.where(
        nl.greater(outer_remaining_f32, zero_f32),
        one,
        zero,
        dtype=loop_counts.dtype,
    )
    outer_loop = nisa.register_alloc(None)
    nisa.register_load(dst=outer_loop, src=outer_alive)

    # Dynamic outer loop (stands in for the Q-tile loop).
    while outer_loop:
        # Dynamic inner loop (stands in for the K/V segment loop).
        for _ in nl.dynamic_range(inner_loop):
            nisa.tensor_tensor(dst=next_total, data1=total, data2=one, op=nl.add)
            nisa.tensor_copy(dst=total, src=next_total)

        # Decrement the outer counter and refresh the loop condition register.
        nisa.tensor_tensor(
            dst=outer_remaining,
            data1=outer_remaining,
            data2=one,
            op=nl.subtract,
        )
        nisa.tensor_copy(dst=outer_remaining_f32, src=outer_remaining)
        outer_alive = nl.where(
            nl.greater(outer_remaining_f32, zero_f32),
            one,
            zero,
            dtype=loop_counts.dtype,
        )
        nisa.register_load(dst=outer_loop, src=outer_alive)

    out = nl.ndarray((1, 1), dtype=loop_counts.dtype, buffer=nl.shared_hbm)
    nisa.dma_copy(dst=out[0:1, 0:1], src=total)
    return out


if __name__ == "__main__":
    artifacts = Path(os.environ["NKI_ARTIFACTS_DIR"])
    replay = Path("/tmp/nki_nested_dynamic_loop_replay")
    shutil.rmtree(artifacts, ignore_errors=True)
    shutil.rmtree(replay, ignore_errors=True)

    # Step 1: normal NKI invocation, which generates the compiler artifacts.
    print(nested_dynamic_loop_repro(np.array([[2], [3]], dtype=np.int32)))

    # Step 2: replay the generated penguin.py through the public neuronx-cc CLI.
    replay.mkdir()
    shutil.copy(artifacts / "penguin.py", replay / "penguin.py")
    for path in artifacts.glob("__main__*.json"):
        shutil.copy(path, replay / path.name)
    subprocess.run(
        [
            "neuronx-cc",
            "compile",
            "--framework",
            "XLA",
            "penguin.py",
            "--pipeline",
            "compile",
            "SaveTemps",
            "--target",
            "inf2",
            "--lnc",
            "1",
            "--output",
            str(replay / "kernel.neff"),
        ],
        cwd=replay,
        check=True,
    )
The replay fails with:
[INTERNAL_ERROR] [NCC_ICFG018] Loop structures require exactly one internal back-edge as we currently do not support CONTINUE statements or nested loops.
and:
with internal back-edges:
Block1_LoopBody_1_Resume_2 -> Block1_LoopBody_1,
Block1_LoopBody_1_LoopBody_2 -> Block1_LoopBody_1_LoopBody_2.
Use Case
The use case is high-throughput LLM serving on Inf2/Trn2/Trn3 with a paged KV cache. Requests are ragged: some batch slots are inactive, decode rows have different live sequence lengths, and prefill/mixed batches may contain prompts with different query lengths and context lengths.
Proposed Solution
Concretely, guidance or compiler support is needed for:
- nested dynamic loops for Q tiles and visible K/V segments;
- bounded runtime K/V streaming loops with loop-carried online-softmax state;
- active-row decode without work for inactive batch slots;
- prefill/mixed attention that streams only the K/V prefix visible to each Q tile;
- support for all of the above on Inf2, Trn2, and Trn3 instances.
Other Information
Compiler environment for the reproducer:
target: inf2.8xlarge
NeuronX compiler: 2.25.3371.0+f524f7f8
NKI: 0.4.0+25940409122.gd30719f9
Acknowledgements