
WIP: add nkigen-lite as a standalone IR-based kernel generation backend#59

Draft
ymwangg wants to merge 45 commits into main from nkigen-lite

Conversation


@ymwangg ymwangg commented Jun 2, 2026

Contributor

Summary

Adds nkigen-lite, a standalone IR-based kernel generation backend that lowers numpy-style tensor programs to NKI (Neuron Kernel Interface) code for NeuronCore targets.

Architecture

The system is structured as a three-layer IR stack with a multi-phase lowering pipeline:

Core (core.py)

Shared SSA-based IR infrastructure used by both IRs:

  • Value, Op, Graph — SSA primitives with use-lists and mutation helpers
  • DType enum covering f32/f16/bf16/tf32/fp8/int types
  • Common graph utilities: DCE, verification, toposort
  • Shared numpy interpreter dispatch tables

Tensor IR (tensor_ir/)

High-level, hardware-agnostic IR operating on whole tensors:

  • SSA-based — every op produces new Value(s), enabling clean analysis and transformation
  • Numpy-like builder API — familiar interface for constructing kernel graphs
  • Numpy interpreter — executes the IR with real data for correctness checking
  • Ops: elementwise (unary/binary), reduce, matmul, transpose, reshape, slice, concat, broadcast

NKI IR (nki_ir/)

Low-level IR that makes hardware concerns explicit:

  • Memory spaces — every value carries HBM/SBUF/PSUM placement
  • Partition dimension — dim 0 of on-chip tiles is the partition dim (max 128)
  • Explicit memory management — alloc/dealloc + DMA copies for data movement
  • Pre-allocated destinations — all compute ops take a dst parameter
  • Tile indexing — DimSlice-based indexing (ts/ds) mirroring Kernel Builder
  • Loop constructs — fori_loop for explicit tile iteration
  • Hardware verifier — checks tile constraints against target specs
  • Numpy interpreter — reference execution without hardware
  • Emit to Kernel Builder — walks the graph and invokes KB API calls to produce NISA MLIR

Lowering Pipeline (tensor_ir/passes/)

The full pipeline: tensor_ir → canonicalize → decompose → layout_solver → direct_lower → nki_ir

  1. Canonicalize — recomposes primitive-op chains into high-level ops (e.g., div(1, sqrt(x)) → rsqrt(x), mul(x, div(1, add(1, exp(neg(x))))) → silu(x))

  2. Decompose — lowers ops without direct NISA equivalents into supported primitives (e.g., div(a,b) → mul(a, reciprocal(b)), reduce(mean) → reduce(sum) * 1/N)

  3. Layout Solver — assigns each tensor dimension to one of three roles:

    • I (iteration) — loop indices, not in SBUF tile
    • P (partition) — SBUF dim-0, product ≤ 128, parallel compute
    • F (free) — SBUF dim-1, contiguous per partition

    Propagates constraints across the graph to find a globally consistent assignment.

  4. Direct Lower — converts tensor IR ops to tiled NKI IR:

    • Segments ops into elementwise groups (fused on-chip) vs individual non-elementwise ops (HBM boundaries)
    • Generates tiled load→compute→store sequences
    • Per-op lowering modules: memory, elementwise, reduce, matmul, transpose, broadcast
    • Inserts deallocs via liveness analysis after lowering
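
The rewrite rules named in the Canonicalize and Decompose steps can be checked numerically. The following is a minimal NumPy sketch of those identities only; it does not use the nkigen-lite pattern classes, and the variable names are illustrative.

```python
import numpy as np

x = np.linspace(0.5, 4.0, 8)
a = np.linspace(1.0, 2.0, 8)
b = np.linspace(3.0, 5.0, 8)

# Canonicalize direction: a chain of primitives collapses into one op.
rsqrt_chain = 1.0 / np.sqrt(x)                # div(1, sqrt(x))
rsqrt_fused = x ** -0.5                       # stands in for rsqrt(x)
silu_chain = x * (1.0 / (1.0 + np.exp(-x)))   # mul(x, div(1, add(1, exp(neg(x)))))
silu_fused = x / (1.0 + np.exp(-x))           # stands in for silu(x)

# Decompose direction: an op without a native equivalent becomes primitives.
div_direct = a / b
div_lowered = a * np.reciprocal(b)            # mul(a, reciprocal(b))
m = np.arange(12, dtype=np.float64).reshape(3, 4)
mean_direct = m.mean(axis=1)
mean_lowered = m.sum(axis=1) * (1.0 / m.shape[1])   # reduce(sum) * 1/N
```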

Hardware Target (passes/hardware.py)

Parameterized hardware profiles (TRN2 defaults) defining partition limits, SBUF/PSUM sizes, and matmul constraints.

Status

🚧 Work in progress — not ready for review.

Test plan

  • Full test suite passes (uv run pytest nkigen-lite/tests/ -n auto)
  • Integration with main nkipy package verified
  • End-to-end lowering produces correct NKI IR for representative patterns

ymwangg added 22 commits June 1, 2026 22:24
Migrates tensor_ir, nki_ir, and the direct lowering passes from
nano-tensorizer/ir_lab into the nkipy workspace as a new package.
The pipeline (canonicalize → decompose → layout_solver → direct_lower)
produces legal NKI IR directly without intermediate passes.
Add nkigen-lite as a fully functional backend (backend="nkigen-lite")
alongside hlo and nkigen. The pipeline traces Python kernels through
nkigen_lite's tensor_ir Builder, lowers via the pass pipeline
(canonicalize → decompose → layout_solver → direct_lower), and compiles
to NEFF via the NKI kernel_builder API.

nkipy integration:
- backend/nkigen_lite.py: TraceContext, Tensor, IR adapter
- ops/_nkigen_lite_impls.py: op implementations delegating to Builder
- ops/_register_nkigen_lite.py: lazy op registration
- trace.py: _specialize_nkigen_lite() dispatch
- compile.py: _compile_nkigen_lite() via kernel_builder
- knob.py, nki_op.py: backend-aware dispatch

nkigen-lite enhancements:
- Builder: add abs, sign, floor, ceil, power, floor_divide, mod ops
- Interpreter: numpy dispatch for new ops, dtype-aware tensor_copy
- Decompose pass: floor_divide/mod use divide-then-verify-and-correct
  strategy (matching neuronx-cc BIR), power→exp(b*log(a)),
  ceil→neg(floor(neg(x))), fixed-point iteration with max-iter guard
- Direct lowering: abs/sign/sin via NisaActivationOp, floor via i32
  truncation + sign correction, cast via tensor_copy, 1D reshape fix
- docs/floor_divide_precision.md: documents the precision strategy

Test results: 134/135 HLO-parity tests pass on trn2 hardware (99.3%).
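
As a quick sanity check of two of the decompositions listed above (power and ceil), here is a NumPy sketch. It only verifies the algebra; note that the exp/log form of power is valid for positive bases only, which is a property of the identity and not something the commit message states.

```python
import numpy as np

# power(a, b) -> exp(b * log(a)); the identity requires a > 0.
a = np.array([0.5, 2.0, 3.0, 10.0])
b = np.array([2.0, -1.5, 0.0, 0.5])
power_lowered = np.exp(b * np.log(a))

# ceil(x) -> neg(floor(neg(x)))
x = np.array([-2.5, -1.0, -0.2, 0.0, 0.2, 1.0, 2.5])
ceil_lowered = -np.floor(-x)
```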
Add "nkigen-lite" to the trace_mode fixture so all parametrized tests
run with both backends. Add a pytest hook that marks NotImplementedError
as xfail for nkigen-lite — ops not yet implemented show as expected
failures and automatically start passing when added.

Current results:
- HLO: 741 passed, 4 xfailed, 42 skipped
- nkigen-lite: ~340 passed, 161 xfailed (unimplemented ops), ~93 failed
  (partial implementations needing further work)
Add ReduceKeepdimsFalsePattern to decompose keepdims=False reductions
into keepdims=True + reshape, which the layout solver and lowering
require. Handle scalar (rank-0) tensors throughout the lowering pipeline
by promoting them to (1,) at the NKI boundary since the hardware doesn't
support rank-0 tensors.

Also fix negative axis normalization in squeeze() and expand_dims().
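
The keepdims rewrite and the rank-0 promotion can be illustrated in plain NumPy. This is a sketch of the shape bookkeeping, not the ReduceKeepdimsFalsePattern code itself.

```python
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
axis = 1

# keepdims=False reduction == keepdims=True reduction followed by a reshape.
kept = x.sum(axis=axis, keepdims=True)             # shape (2, 1, 4)
out_shape = x.shape[:axis] + x.shape[axis + 1:]    # shape (2, 4)
lowered = kept.reshape(out_shape)

# A rank-0 value is promoted to shape (1,) at the hardware boundary.
scalar = np.float32(3.0)
promoted = np.reshape(scalar, (1,))
```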
- matmul: add 1D→2D promotion following NumPy semantics
- squeeze: validate non-1 dims, normalize negative axis
- reshape: handle int newshape argument
- zeros/full: handle int shape argument
- concatenate: handle single-tensor case, validate empty/axis bounds
- split: validate axis bounds and unequal division
- where: handle numpy array condition argument
- _ensure_value: handle numpy array operands (uniform-fill)
- expand_dims: validate duplicate axes and out-of-bounds axis
- Skip test_reduce_unsupported_op and test_topk_non_last_axis for
  non-HLO backends since they test HLO-specific internal behavior
NeuronCore hardware only supports Add and Max for cross_lane_reduce_arith.
Implement MIN as -max(-x) transparently in the NKI IR builder so all
existing P-dimension reduce codepaths work with min reductions.
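
The min-through-max identity used here is easy to confirm in NumPy (illustrative data, not the builder code):

```python
import numpy as np

x = np.array([[3.0, -1.0, 4.0], [1.0, -5.0, 9.0]], dtype=np.float32)

# min(x) == -max(-x), so only a max reduction is needed.
min_via_max = -np.max(-x, axis=0)
```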
Replace HLO-specific DeviceKernel.compile_and_load path with the
shared on_device_test utility which handles input/output naming
differences between backends automatically.
- broadcast_to: handle scalar (rank-0) source by loading the single
  element and broadcasting via tensor_scalar_arith with ones
- emit_to_kb: auto-cast f16/bf16 operands to f32 around
  tensor_scalar_arith since the hardware scalar engine requires f32
Add NisaBitvecOp enum and tensor_tensor_bitvec builder method to NKI IR.
Wire through the full pipeline: tensor IR opcodes, elementwise lowering,
emit_to_kb mapping, and interpreter support. Replace the old arithmetic
approximations (which only worked for booleans) with hardware bitwise
instructions that work correctly on integer types.
Add NKI IR primitives for mixed-dtype operations:
- tensor_tensor_compare: comparison ops (IsGT, IsGE, etc.) that accept
  float inputs and produce uint8 predicate output
- tensor_scalar_bitvec: scalar bitvec ops (XOR for logical NOT, etc.)
- Comparison and logical op variants in NisaArithOp enum

Rewrite _emit_floor to use the NKI compiler's compare+select pattern:
trunc→compare→conditional select in integer domain, avoiding float
precision issues in the correction step.
…lice

- emit_slice: add strides parameter; delegate to _emit_strided_slice
  for non-unit strides (element-by-element DMA for F-stride, row-by-row
  for P-stride)
- dynamic_update_slice: handle numpy array value argument (uniform fill)
Add DType.FP8_E4M3_IEEE for the IEEE-standard float8_e4m3 format
(distinct from the NaN-free float8_e4m3fn variant already supported).
Wire through core, emit_to_kb, and compile dtype mappings.
Each pytest-xdist worker now claims a specific Neuron core via
NEURON_RT_VISIBLE_CORES, enabling parallel test execution across
all 64 available cores (~8.5x speedup).
- Comparison ops (equal, not_equal, greater, less, etc.) now produce
  same dtype as input (1.0/0.0 float) matching NKI convention, instead
  of DType.BOOL
- where op lowered using NKI pattern: cond*x + (1-cond)*y with all
  float arithmetic — no mixed-dtype operations needed
- Map DType.BOOL → uint8 in kernel builder and execution layer
- Update tensor IR builder to remove BOOL requirement from where
- Reduces xfail count from 162 → 125 (37 tests now passing)
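
A small NumPy sketch of the arithmetic select described above, assuming the comparison result is stored as 1.0/0.0 in the input dtype. One caveat of this form in general: a NaN or inf in the unselected branch still propagates through the multiply.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
y = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)
a = np.array([0.0, 5.0, 1.0, 9.0], dtype=np.float32)

# Comparison yields 1.0/0.0 in the input dtype instead of a bool.
cond = (a > 2.0).astype(np.float32)

# where(cond, x, y) as pure float arithmetic.
selected = cond * x + (1.0 - cond) * y
```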
Matches NKI compiler's approach: cos(x) = sin(x + π/2).
The hardware sin activation instruction handles the computation.
Implement np.dot semantics as a composed op:
- 1D/2D cases delegate directly to matmul
- N-D × 1D delegates to matmul (batched matrix-vector)
- N-D × M-D decomposes to reshape + matmul + reshape to achieve
  the outer-product batch semantics of np.dot
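
One way to express the N-D × M-D case with a single 2-D matmul is sketched below. The helper name `dot_nd` is hypothetical, and this sketch also moves the contraction axis of `b` to the front before flattening, which the commit message does not spell out.

```python
import numpy as np

def dot_nd(a, b):
    """np.dot for a.ndim >= 2 and b.ndim >= 2 via reshape + matmul + reshape."""
    k = a.shape[-1]
    a2 = a.reshape(-1, k)                      # (prod(a_lead), K)
    b_front = np.moveaxis(b, -2, 0)            # (K, *b_lead, M)
    b2 = b_front.reshape(k, -1)                # (K, prod(b_lead) * M)
    out2 = a2 @ b2                             # one 2-D matmul
    return out2.reshape(a.shape[:-1] + b_front.shape[1:])

a = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
b = np.arange(40, dtype=np.float64).reshape(2, 4, 5)
result = dot_nd(a, b)
```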
- arctan: wire native NISA ARCTAN activation through the Builder and
  direct-lower tables
- invert/bitwise_not: composed_impl as XOR with all-ones (-1), matching
  the NKI compiler's implementation
- logical_and/or/xor: composed_impl via 0/1 truthiness; also unblocks
  rint/round which depend on logical_and
- constant: backend impl mirroring HLO (passthrough + uniform fill);
  non-uniform array constants raise NotImplementedError

Fix a pre-existing bug in _emit_broadcast_scalar that fed a (1,1) tile
to tensor_scalar_arith whose scalar operand partition dim must match the
destination; replicate to (p_size, 1) via broadcast_partition.

Also make test_ml_dtypes_constant_encoding's float8 xfails backend-aware
so float8_e5m2 on nkigen-lite no longer reports XPASS.
The slice-based gather produced wrong output shapes: it ignored
axis=None (no flatten), concatenated slices along the original axis
instead of replacing it with indices.shape, and mishandled scalar and
multi-dimensional index arrays.

Rewrite to match numpy:
  out.shape == a.shape[:axis] + indices.shape + a.shape[axis+1:]

- axis=None flattens the input first
- negative indices are normalized modulo the axis dimension
- each flat index becomes a width-1 slice; slices are concatenated then
  reshaped so the gathered axis is replaced by indices.shape (dropped
  entirely for a scalar index)

Fixes 13 failing test_take_scalar / test_take_numpy_indices cases.
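
The shape rule and the slice-concatenate-reshape strategy can be sketched in NumPy as follows. `take_via_slices` is a hypothetical stand-in for the rewritten gather, assuming non-empty, in-range indices.

```python
import numpy as np

def take_via_slices(a, indices, axis=None):
    a = np.asarray(a)
    idx = np.asarray(indices)
    if axis is None:                 # axis=None flattens the input first
        a = a.reshape(-1)
        axis = 0
    axis = axis % a.ndim
    n = a.shape[axis]
    pieces = []
    for i in idx.reshape(-1):
        i = int(i) % n               # normalize negative indices
        sl = [slice(None)] * a.ndim
        sl[axis] = slice(i, i + 1)   # width-1 slice
        pieces.append(a[tuple(sl)])
    gathered = np.concatenate(pieces, axis=axis)
    # The gathered axis is replaced by indices.shape (dropped for a scalar).
    out_shape = a.shape[:axis] + idx.shape + a.shape[axis + 1:]
    return gathered.reshape(out_shape)

a = np.arange(24).reshape(2, 3, 4)
```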
…) for nkigen-lite

Wires the four distributed collectives through the full nkigen-lite stack:

- tensor_ir Builder: collective ops with correct output-shape inference
  (all_gather grows the gather dim, reduce_scatter/all_to_all shrink/grow
  by world size)
- nkipy lite impls + registration, mapping numpy reduce ufuncs to the
  collective reduce-op names
- direct_lower: stage collectives through internal HBM scratch buffers
  (the compiler forbids collectives from reading/writing kernel IO
  tensors directly — "Collective instruction cannot read IO tensors")
- nki_ir Builder: collective() side-effect node (HBM->HBM)
- emit_to_kb: lower to nisa.all_reduce/all_gather/reduce_scatter/all_to_all
  via ExplicitReplicaGroupAttr + dma_compute_reduce_op

The KB collective API only operates on the last (free) axis of 2D HBM
tensors (cc_dim=0 raises std::bad_cast), so all_gather/reduce_scatter
along other axes are staged via transpose-collective-transpose.

Fixes the all_reduce/all_gather/reduce_scatter/all_to_all xfails
(multiply-reduce variants stay xfailed for both backends — unsupported
by the compiler).
The earlier transpose workaround for all_gather/reduce_scatter was based on
a misdiagnosis: cc_dim=0 appeared to raise std::bad_cast, so collectives
were staged through a transpose to operate on the last axis. Multi-core
numerical verification showed that path silently dropped the remote rank's
data (all_gather duplicated the local source; reduce_scatter ignored the
per-rank scatter offset).

Root cause: the KB nisa collective APIs forward cc_dim to the native
builder un-converted, so a bare int 0 fails the int->enum cast. The NKI
collectives contract also requires collective_dim=0 for HBM tensors.

Fix:
- emit_to_kb: convert the int dim to CollectiveDimension (DIM_0/DIM_1)
  before calling nisa.all_gather/reduce_scatter/all_to_all
- drop the transpose workaround; gather/scatter along the requested dim
  directly

Verified on 2 NeuronCores with distinct per-rank data: all_reduce,
all_gather(dim0), reduce_scatter(dim0), and all_to_all all produce the
correct cross-rank results.
Two bugs combined to make a // b and a % b off by one on the rare inputs
where the true quotient is an exact integer:

1. The composed floor_divide impl (floor(divide(x, y))) ran at trace time,
   so the graph never contained a `floor_divide` opcode and the decompose
   pass's divide-then-verify-and-correct FloorDividePattern never fired.
   NeuronCore has no native divide -- it uses reciprocal multiply, which
   undershoots exact integers (2.0 -> 1.9999999), so plain floor gave N-1.
   Fix: register nkigen-lite-specific floor_divide/remainder impls that
   emit the native floor_divide/mod opcodes, so the correcting pattern runs.

2. Within FloorDividePattern, the up-correction used
   max(0, sign(|rem| - |b|)), which is 0 at the |rem| == |b| boundary
   (sign(0) == 0) -- exactly the exact-integer undershoot case. Replace
   with an inclusive greater_equal(|rem|, |b|); a genuine remainder is
   always strictly < |b|, so equality can only mean undershoot.

Fixes the 3 failing floordiv/mod broadcasting cases. Exact-integer-boundary
inputs remain inherently ambiguous under reciprocal division (numpy and the
device can disagree by an ULP), but all test cases now pass.
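
The boundary bug in the up-correction can be reproduced without hardware by simulating the undershoot. In this sketch the undershot quotient is supplied by hand, and the two helper names are hypothetical; only the two correction formulas come from the commit message.

```python
import numpy as np

def up_correction_old(rem, b):
    # max(0, sign(|rem| - |b|)): evaluates to 0 when |rem| == |b|
    return np.maximum(0.0, np.sign(np.abs(rem) - np.abs(b)))

def up_correction_new(rem, b):
    # inclusive greater_equal(|rem|, |b|)
    return (np.abs(rem) >= np.abs(b)).astype(np.float32)

a = np.array([6.0, 7.0], dtype=np.float32)
b = np.array([3.0, 3.0], dtype=np.float32)

# Simulated reciprocal-multiply quotients: 6/3 undershoots 2.0.
approx = np.array([1.9999999, 2.3333333], dtype=np.float32)
q0 = np.floor(approx)          # [1.0, 2.0]
rem = a - q0 * b               # [3.0, 1.0]

q_old = q0 + up_correction_old(rem, b)
q_new = q0 + up_correction_new(rem, b)
```

With these inputs the old formula leaves the first quotient at 1, while the inclusive comparison restores it to 2.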

ymwangg commented Jun 23, 2026

Contributor Author

Status Update: Replacing the HLO backend with nkigen-lite

Reported 2026-06-23 from full test suite run (uv run pytest tests/ -n auto).

Where we are

Over Jun 2–5, nkigen-lite was stood up as a third backend and brought to broad
parity with hlo across elementwise math, bitwise, reductions, matmul,
collectives, indexing, and an fp8 dtype.

Current suite status: 8 failed, 1291 passed, 103 skipped, 104 xfailed

To finish retiring the hlo backend, nkigen-lite still needs to clear the cases
that pass on hlo but fail/xfail on nkigen-lite: 2 hard failures + 100 xfailed
cases remaining. (2 additional xfails are a compiler limitation, not a
nkigen-lite gap — see §4.)

Note: the 6 TestWriteFromTorch failures are environmental (torch not
installed) and are not backend-specific — excluded from this report.


1. Hard failures — numerical bugs (2 cases) — fix first

These already run but produce wrong results:

| Test | Mismatch | Likely cause |
|---|---|---|
| test_indexing_slicing_comprehensive.py::...::test_view_assignment_semantics[nkigen-lite] | 3/128 (2.34%) | strided-slice / dynamic_update_slice (8cdd829) |
| test_kernels.py::test_kernel_default[nkigen-lite-rope_dynamo.py:0-kernel_spec2] | 232/448 (51.8%), max abs ~305 | cos-via-sin(x + π/2) decomposition (f5ec24b) |

2. Unimplemented ops (58 cases)

Each op unblocks its cluster of tests:

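| Op | Tests |
|---|---|
| repeat | 9 |
| take_along_axis | 6 |
| put_along_axis | 6 |
| argmax | 6 |
| scatter_along_axis | 5 |
| pad | 5 |
| argmin | 4 |
| diag | 3 |
| triu | 2 |
| tril | 2 |
| tile | 2 |
| roll | 2 |
| flip | 2 |
| diff | 2 |
| trace | 1 |
| scatter_strided | 1 |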

3. Capability gaps — broader subsystems (42 cases)

| Gap (xfail reason) | Tests | Affected test groups |
|---|---|---|
| Dynamic tensor indexing not yet supported | 24 | test_take (14), embedding-lookup / view-as-index patterns, test_slice_extraction, test_rotary_embed |
| nki modes other than HLO not implemented | ~9 | test_nki_with_grid, test_nki_simple, test_nki_direct_jit, test_nki_mutable_tensor, test_nki_direct_jit_with_grid |
| Non-uniform array constants not supported | 4 | test_cumsum family, test_constant_hlo_list_tuple |
| split with explicit indices | 1 | test_split_indices |

4. Out of scope — compiler limitation (2 cases — leave xfailed)

| Test | Reason |
|---|---|
| test_collectives.py::test_all_reduce_multiply | Compiler does not support multiply for reduce-scatter collectives |
| test_collectives.py::test_reduce_scatter_multiply | (same) |

These cannot be cleared by nkigen-lite work — hardware/compiler limitation.


Remaining work, in priority order

  1. Fix the 2 numerical bugs — correctness regressions, highest priority.
  2. Dynamic tensor indexing — single biggest unlock (24 tests); underpins
    take / embedding / rope patterns.
  3. nki non-HLO mode support — ~9 tests; needed for the JIT/grid path.
  4. High-count ops: repeat, take_along_axis, put_along_axis,
    argmax/argmin, scatter_along_axis, pad.
  5. Non-uniform array constants (cumsum family), then the long tail of
    single-test ops: diag, tril/triu, tile, roll, flip, diff,
    trace, split-indices.

Clearing §1–§3 retires ~100 of the 102 hlo-only-passing cases; the 2
collective-multiply tests (§4) remain xfailed as a compiler limitation.

ymwangg added 7 commits June 23, 2026 15:37
NRT does not carry FP8 dtypes through compiled neff metadata: e4m3/e5m2
surface as "int8" and e4m3fn surfaces as "unknown". The data round-trips
correctly on device, so the only blocker was spike's dtype validation,
which special-cased only e4m3/e5m2 and only accepted "int8".

Extend _check_dtype_compatibility to cover e4m3fn and accept either
placeholder ("int8"/"unknown"), gated behind the FP8 dtype set so
non-FP8 tensors are still validated strictly. Drop the matching
nkigen-lite xfail in test_ml_dtypes_constant_encoding.
The high-level tensor_ir.Builder had no index-ramp op, while the low-level
nki_ir already exposed nisa.iota. Bridge that gap across all four layers:

- tensor_ir/ir.py: Builder.iota(shape, dim, dtype) — out[..., i, ...] == i
  along dim, broadcast over other axes (np.arange-on-axis semantics).
- core.py: numpy eval in eval_common_op (shared by both interpreters).
- direct_lower.py: _emit_iota_op tiles under the canonical row-major layout
  and maps each axis to nisa.iota's pattern/channel_multiplier/offset
  (free -> step 1; partition -> channel_multiplier 1 + p_off; batch ->
  constant offset). Kept out of ELEMENTWISE_OPCODES since it is
  position-dependent.

Unblocks tril/triu/diag/trace, which build index masks via iota.

Adds TestIota HW coverage (per-axis, multi-tile partition, rank-3, and
iota feeding an elementwise op).
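
The stated semantics (out[..., i, ...] == i along dim, broadcast over the other axes) correspond to the following NumPy reference. This is a sketch of the semantics, not the Builder.iota implementation or its numpy eval in core.py.

```python
import numpy as np

def iota(shape, dim, dtype=np.int32):
    """Index ramp along `dim`, broadcast across all other axes."""
    dim = dim % len(shape)
    ramp_shape = [1] * len(shape)
    ramp_shape[dim] = shape[dim]
    ramp = np.arange(shape[dim], dtype=dtype).reshape(ramp_shape)
    return np.broadcast_to(ramp, shape).copy()
```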
Implement gap-8 triangular/diagonal ops on top of the new iota primitive,
mirroring the HLO impls: build row/col index masks via iota + compare, then
where(mask, x, 0).

- tril/triu: keep row >= col-k (lower) or row <= col-k (upper); masks built
  over the last two axes, broadcast across batch dims.
- diag 1D->2D: extend v to length N with a zero-pad on the side away from the
  diagonal, broadcast across columns, keep col == row+k. Avoids the HLO
  take-based gather (dynamic indexing is unsupported on nkigen-lite).
- diag 2D->1D: mask the k-th diagonal and reduce-sum the off-axis to collapse
  to the diagonal vector; slice to diag_len.
- trace: mask the diagonal (offset) and reduce-sum both axes.

Flips 8 xfails to pass in tests/unit/test_tensor_api.py. Verified against
numpy on non-square diag/tril/triu and trace-with-offset edge cases beyond
the existing suite.
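
The mask conditions quoted above can be checked against NumPy directly. The sketch below uses np.arange ramps in place of the iota op, and the function names are hypothetical.

```python
import numpy as np

def _row_col(x):
    rows = np.arange(x.shape[-2])[:, None]
    cols = np.arange(x.shape[-1])[None, :]
    return rows, cols

def tril_mask(x, k=0):
    rows, cols = _row_col(x)
    return np.where(rows >= cols - k, x, 0)    # keep row >= col - k

def triu_mask(x, k=0):
    rows, cols = _row_col(x)
    return np.where(rows <= cols - k, x, 0)    # keep row <= col - k

def trace_mask(x, offset=0):
    rows, cols = _row_col(x)
    return np.where(cols == rows + offset, x, 0).sum()

x = np.arange(15).reshape(3, 5)
batch = np.arange(24).reshape(2, 3, 4)
```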
Implement gap-7 (pad) and gap-9 (flip/roll/tile/diff) as pure slice/concat
data movement — no new primitives needed.

- flip: reverse an axis by concatenating width-1 slices in descending order.
- tile: concatenate r copies of the running result along each axis.
- roll: cyclic shift via split at (n-shift) + swapped concat; supports
  axis=None (flatten), int axis, and tuple shift/axis.
- diff: iterated x[1:] - x[:-1] along axis (n times).
- pad: constant mode concatenates full-valued slabs; edge mode replicates
  the first/last slab. Handles scalar, per-axis, and asymmetric pad_width.

Flips 14 xfails to pass in tests/unit/test_tensor_api.py. Verified against
numpy on 3-D flip, multi-axis flip/roll, flattened roll, 3-D tile, and 3-D
asymmetric pad beyond the existing suite.
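
A NumPy sketch of three of these ops built only from slices and concatenation. The helper names are hypothetical; the real implementations go through the lite Builder.

```python
import numpy as np

def slice_axis(x, start, stop, axis):
    sl = [slice(None)] * x.ndim
    sl[axis] = slice(start, stop)
    return x[tuple(sl)]

def flip_via_slices(x, axis):
    n = x.shape[axis]
    # width-1 slices concatenated in descending order
    pieces = [slice_axis(x, i, i + 1, axis) for i in reversed(range(n))]
    return np.concatenate(pieces, axis=axis)

def roll_via_split(x, shift, axis):
    n = x.shape[axis]
    s = shift % n
    # split at (n - shift), then concatenate the halves swapped
    return np.concatenate(
        [slice_axis(x, n - s, n, axis), slice_axis(x, 0, n - s, axis)], axis=axis)

def diff_via_slices(x, n, axis):
    for _ in range(n):
        size = x.shape[axis]
        x = slice_axis(x, 1, size, axis) - slice_axis(x, 0, size - 1, axis)
    return x

x = np.arange(24).reshape(2, 3, 4) ** 2
```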
Implement argmax/argmin via index masking on top of iota, mirroring the HLO
algorithm: reduce to the extreme value along the axis, mark positions equal
to it with their index (an iota ramp) and all others with a large sentinel,
then min-reduce the indices — returning the first index that attains the
extreme, matching numpy's tie-break.

The whole computation runs in float32 (cast input up front, cast result to
int32 at the end): min/max reductions initialize with +/-inf, which cannot
be memset into an integer tile, so an integer input or index ramp would fail
to compile.

Handles axis=None (flatten), negative axis, and keepdims. Flips 8 xfails to
pass; verified against numpy on int inputs and tie-breaking beyond the suite.
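
The index-masking algorithm reads as follows in NumPy. The sentinel here is the axis length, which is an assumption of this sketch (the commit only says "a large sentinel"); any value above the largest valid index works.

```python
import numpy as np

def argmax_via_mask(x, axis):
    xf = x.astype(np.float32)                        # run everything in float32
    extreme = xf.max(axis=axis, keepdims=True)
    ramp_shape = [1] * xf.ndim
    ramp_shape[axis] = xf.shape[axis]
    ramp = np.arange(xf.shape[axis], dtype=np.float32).reshape(ramp_shape)
    sentinel = np.float32(xf.shape[axis])            # assumed sentinel value
    candidates = np.where(xf == extreme, ramp, sentinel)
    # The min over indices returns the first position attaining the extreme.
    return candidates.min(axis=axis).astype(np.int32)

x = np.array([[1, 5, 5, 2], [7, 0, 7, 7]])
```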
The lite builder only emits uniform fills, so non-uniform array constants
previously raised NotImplementedError. Materialize them as a flat sequence
of run-length fills, concatenate, and reshape — cheap for structured/small
arrays, capped at 4096 runs to keep tracing bounded. Route the binary-operand
path (_ensure_value) through the same logic.

cumsum gets a dedicated nkigen-lite impl rather than relying on the composed
fallback (which builds constant(np.triu(np.ones((N,N))))): for axis=None the
flattened triangular matrix is (4096,4096), far too large as a literal. Build
U[i,j] = (i<=j) via iota + compare instead, then cumsum = x_2d @ U. Handles
axis=None/negative/middle axis and dtype.

Flips 8 xfails to pass (4 cumsum, list/tuple constant, integer where-cond).
Verified against numpy on 1-D/3-D/negative-axis cumsum and small structured
constants beyond the suite.
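
The cumsum-as-matmul construction can be sketched in NumPy as below, with np.arange standing in for iota. The helper name is hypothetical.

```python
import numpy as np

def cumsum_via_matmul(x, axis=-1):
    xm = np.moveaxis(x, axis, -1)
    n = xm.shape[-1]
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    U = (i <= j).astype(x.dtype)          # U[i, j] = 1 where i <= j
    out = xm.reshape(-1, n) @ U           # out[r, j] = sum over i <= j of x[r, i]
    return np.moveaxis(out.reshape(xm.shape), -1, axis)

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
```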
nkigen-lite has no convolution primitive, so decompose N-D conv into im2col +
a single matmul: gather each kernel position's strided window as a
(N, Ci, out_pts) block, concat along the channel axis into
(N, Ci*prod(K), out_pts), flatten the (transposed) weight to (Co, Ci*prod(K)),
and do one batched matmul -> (N, Co, out_pts).

A single fused matmul compiles ~35% faster than accumulating prod(K) separate
matmuls (95.6s -> 61.5s on ic=16/oc=32/k=3); the official conv2d suite drops
from 165s to 121s. Spatial padding is built from concat of zero slabs;
strided/dilated windows use strided slice. groups != 1 is unsupported.

Verified against PyTorch on conv2d (stride/padding/dilation/bias/1x1) and
conv3d including a non-cubic (2,3,3) kernel. conv tests require torch as the
oracle, now installed via the examples dependency group.
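
The im2col layout described above can be sketched for the simplest case (2-D, stride 1, no padding, groups == 1) in NumPy. The function name and the kernel-position-major channel ordering are assumptions of this sketch; the weight transpose must simply match whatever concat order is used.

```python
import numpy as np

def conv2d_im2col(x, w):
    n, ci, h, wd = x.shape
    co, _, kh, kw = w.shape
    oh, ow = h - kh + 1, wd - kw + 1
    blocks = []
    for i in range(kh):
        for j in range(kw):
            window = x[:, :, i:i + oh, j:j + ow]           # one kernel position
            blocks.append(window.reshape(n, ci, oh * ow))  # (N, Ci, out_pts)
    cols = np.concatenate(blocks, axis=1)                  # (N, Kh*Kw*Ci, out_pts)
    # Transpose so the flattened weight matches the concat order above.
    w2 = w.transpose(0, 2, 3, 1).reshape(co, kh * kw * ci)
    out = w2 @ cols                                        # (N, Co, out_pts)
    return out.reshape(n, co, oh, ow)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 5))
w = rng.standard_normal((4, 3, 3, 2))
y = conv2d_im2col(x, w)
```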
ymwangg added 16 commits June 23, 2026 23:41
repeat: insert a size-1 axis after the repeat axis, broadcast it to the
repeat count, then reshape to fold it back in — duplicating each element in
place (matching np.repeat / the HLO impl). Scalar integer repeats only.

split: handle an explicit list of split indices (numpy semantics — boundaries
clamped to the axis size, len(indices)+1 sub-arrays). Empty sub-arrays
(repeated/out-of-range index) raise NotImplementedError, since the lite IR
has no zero-size tensor representation.

Flips 10 xfails to pass (9 repeat variants + split_indices). Verified repeat
on axis=None/middle-axis beyond the suite.
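
The repeat construction (insert a size-1 axis, broadcast, fold back) looks like this in NumPy; `repeat_via_broadcast` is a hypothetical name and covers scalar integer repeats only.

```python
import numpy as np

def repeat_via_broadcast(x, repeats, axis):
    axis = axis % x.ndim
    expanded = np.expand_dims(x, axis + 1)        # size-1 axis after the repeat axis
    target = list(expanded.shape)
    target[axis + 1] = repeats
    tiled = np.broadcast_to(expanded, target)     # broadcast to the repeat count
    out_shape = list(x.shape)
    out_shape[axis] *= repeats
    return tiled.reshape(out_shape)               # fold the new axis back in

x = np.arange(6).reshape(2, 3)
```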
The hardware SIN activation is only accurate near [-π, π]; outside it the
polynomial approximation diverges wildly (cos(500) returned ~2e7 instead of a
value in [-1, 1]). This silently corrupted any kernel feeding large arguments
into sin/cos — e.g. rope_dynamo, where bmm produces values up to ~512, showed
a ~52% output mismatch.

Add SinRangeReductionPattern to the decompose pass: sin(x) → sin(x - 2π·round(
x/2π)), with round(y)=floor(y+0.5), bringing the argument into [-π, π] before
the hardware SIN. The emitted inner sin carries a range_reduced attr so the
pattern doesn't re-match and loop. Runs after CosPattern so cos→sin lowering
happens first and both get reduced.

Fixes the rope_dynamo integration test (now passing) and makes cos/sin
accurate to <1e-4 across ±500 (was max error ~2e7). One pre-existing bug
remains: view-assignment aliasing.
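
The range reduction can be checked in NumPy. Here np.sin stands in for the hardware SIN, so this sketch only confirms that the reduced argument stays within [-π, π] and that the value is preserved; it does not model the hardware approximation error.

```python
import numpy as np

def reduce_range(x):
    two_pi = 2.0 * np.pi
    k = np.floor(x / two_pi + 0.5)      # round(y) = floor(y + 0.5)
    return x - two_pi * k

def sin_reduced(x):
    return np.sin(reduce_range(x))

def cos_via_sin(x):
    # cos(x) = sin(x + pi/2), then the same range reduction applies
    return sin_reduced(x + np.pi / 2.0)

x = np.linspace(-500.0, 500.0, 2001)
r = reduce_range(x)
```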
No sort primitive exists, so extract top-k iteratively: k times, take the max
along the axis, record value + argmax index, then mask that single position to
-inf so the next round finds the following element. Masking by index position
(via an iota ramp) rather than value gives a stable lowest-index pick among
equal maxima. Supports descending (default) and ascending (negate in/out);
returns (values, uint32 indices) matching torch.topk.

Flips 8 xfails to pass, verified against torch.topk. Note: for exact-value
ties the index ordering may differ from torch (both are valid top-k results);
random-data tests are unaffected.
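
A NumPy sketch of the iterative extraction along the last axis, descending order only. The function name is hypothetical and the sketch assumes k does not exceed the axis length.

```python
import numpy as np

def topk_iterative(x, k):
    work = x.astype(np.float32)
    n = work.shape[-1]
    ramp = np.arange(n)
    values, indices = [], []
    for _ in range(k):
        v = work.max(axis=-1, keepdims=True)
        # lowest index among positions equal to the current max
        idx = np.where(work == v, ramp, n).min(axis=-1, keepdims=True)
        values.append(v)
        indices.append(idx)
        # mask by index position, not by value, so ties are taken one at a time
        work = np.where(ramp == idx, -np.inf, work)
    return (np.concatenate(values, axis=-1),
            np.concatenate(indices, axis=-1).astype(np.uint32))

x = np.array([[3, 1, 3, 2], [0, 9, 4, 9]])
vals, idxs = topk_iterative(x, 3)
```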
Wire the max8 / find_index8 selection primitives through every IR layer
(nki_ir Builder + interpreter + emit_to_kb, and the tensor_ir Builder +
interpreter + direct_lower), mirroring the iota plumbing. max8 returns the 8
largest values per partition along the free dim.

topk now computes values with hardware max8 (one instruction) instead of k
iterative reduce-max passes: move the topk axis to the free dim, flatten
leading dims to the partition dim, pad the free dim to >=8 with -inf, max8,
slice the first k. Limited to k<=8.

Indices are still recovered via the iota position-ramp argmin trick, NOT
find_index8: find_index8 (the MaxIndex instruction) fails the neuronx-cc ISA
check [NCC_IXCG864] on this target, so it is wired through the IR (and works
in the interpreter) but not used in the topk lowering.

All 8 topk tests pass against torch; verified 1-D, ascending, and index
tie-break edge cases.
Replace the k<=8 max8-only topk with the scanning loop used by nkilib's
topk_core, which supports any k and recovers indices from hardware.

Wire nc_match_replace8 through the IR (nki_ir Builder two-result op +
interpreter + emit_to_kb via nisa.max_index_and_match_replace). Model topk as
a single tensor_ir op (2 results: values + uint32 indices) whose lowering runs
the SBUF-resident scan: ceil(k/8) folds of max8 (next 8 largest) +
match_replace8 (record indices, mask taken values to -inf). Free dim < 8 is
padded with -inf; each fold's results DMA-store to the matching output column
slice.

Key fixes found by incremental testing:
- find_index8 is gen2-only and fails the ISA check on trn2; match_replace8
  (dst_idx) is the gen3+ index path.
- match_replace8 indices must be uint32, not int32 (DVE_READ_INDICES AP
  constraint NCC_IXCG988).
- insert_deallocs special-cases match_replace8's two dst buffers: result[1]
  aliases input[1] (dst_idx), not input[0] — the generic rule mis-freed the
  index tile ("use of released tile").

All 8 topk tests pass; verified k>8 (k=10,16 multi-fold), F<8 padding, 1-D,
and ascending against torch including indices.
…lash)

The nkigen-lite/tests HW suite failed entirely under `-n auto`. Root cause was
NOT plain core contention: spike._spike and nki.runtime._spike are separate
compiled extension modules that collide in CPython's loader — whichever is
imported second resolves to the first, raising
"ImportError: cannot import name 'ModelTensorInfo'". The nkigen-lite HW tests
execute via nki.runtime, so any prior `import spike` in a worker broke them.
This reproduced even at -n 1 (one worker) while -n 0 passed.

Populate the empty nkigen-lite/tests/conftest.py to:
- count NeuronCores by enumerating /dev/neuron* instead of importing spike,
  so the runtime module clash never occurs in a worker;
- pin each xdist worker to its own core (NEURON_RT_VISIBLE_CORES) and cap
  -n auto worker count to the core count, mirroring tests/conftest.py;
- register the `hw` marker (removes PytestUnknownMarkWarning; -m "not hw"
  now selects correctly).

Full nkigen-lite suite: 685 passed, 1 xfailed under -n auto (was 52+ HW
failures). Note: run serial suites with -n0, not -p no:xdist, so the
pytest_xdist_auto_num_workers hook stays valid.
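
The per-worker core pinning can be sketched as two small helpers. The function names are hypothetical; PYTEST_XDIST_WORKER is the environment variable pytest-xdist sets in each worker (values like "gw0", "gw1"), and the modulo is only a fallback in this sketch, since the conftest described above caps the worker count at the core count.

```python
import os

def core_for_worker(worker_id: str, num_cores: int) -> int:
    """Map an xdist worker id such as 'gw3' to a core index."""
    if not worker_id.startswith("gw"):   # e.g. 'master' when xdist is inactive
        return 0
    return int(worker_id[2:]) % num_cores

def pin_worker(num_cores: int, env=os.environ) -> int:
    """Set NEURON_RT_VISIBLE_CORES for the current worker and return the core."""
    worker_id = env.get("PYTEST_XDIST_WORKER", "master")
    core = core_for_worker(worker_id, num_cores)
    env["NEURON_RT_VISIBLE_CORES"] = str(core)
    return core
```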
conv3d lowering hangs the suite on large-channel cases (the Qwen3-VL
3->1152 case never completes), so xfail won't help (the body still
runs). conv2d passes but is very slow (~1-2 min/case). Skip both on
nkigen-lite to keep the suite fast and non-hanging; HLO coverage is
unaffected. Remove once the conv lowering path is optimized.
Add a 2-D per-partition runtime gather primitive (gather_along_axis) to
tensor_ir, lowering to the hardware nisa.gather instruction with partition
tiling for P > 128. Wire take_along_axis through it in the nkipy frontend:
broadcast indices on non-axis dims (matching HLO), move the gather axis to
the free dim, flatten leading dims to the partition dim, gather, then
reshape/transpose back. Handles axis=None and negative axes.

Unblocks take_along_axis and diagonal_gather (which routes through
take_along_axis) on nkigen-lite; verified on Trainium hardware.

New tests: nkigen-lite/tests/tensor_ir/test_gather.py (interpreter,
lowering gate, and hardware across 6 shapes incl. P>128 tiling).
Route np.take with runtime indices through the gather_along_axis hardware
primitive instead of raising NotImplementedError: move the gather axis to
the free dim, flatten leading dims to the partition dim, broadcast the
shared index vector across partitions, gather, then reshape/transpose the
(P, M) result back to numpy's take layout (axis replaced by indices.shape).

Guards for cases that can't lower:
- non-integer indices (boolean masks; nkigen-lite reports comparisons as
  f32 so the frontend bool guard misses them) -> NotImplementedError, which
  preserves the boolean-indexing negative tests.
- flattened partition extent P that is >128 and not a multiple of 128, a
  pre-existing reshape-lowering limit -> NotImplementedError (clean xfail).
Also skip the no-op inverse transpose for scalar/degenerate indices, which
otherwise tripped the rank-1 transpose lowering.

Unblocks dynamic take (test_take), the view-as-index MoE patterns, and the
gather half of rotary_embed; verified on Trainium.
Add the scatter-family ops by fixing and building on the indirect-DMA
primitive (dma_copy_indirect), which was previously unused and miswired.

nki_ir primitives (tensor_ir):
- scatter_rows: out = base.copy(); out[idx[r],:] = updates[r,:]  -> indirect
  DMA store, via the canonical dst.ap(vector_offset=) view + dma_copy.
- gather_rows: out[r,:] = src[idx[r],:]  -> indirect DMA load, same idiom on
  the source.  Both tile over PARTITION_MAX and take (M,1) U32 indices.
- Fix emit_to_kb store/load: the old low-level dma_copy_indirect calls passed
  mismatched full tiles (src.size != dst.size); use the .ap(vector_offset=)
  view so the DMA tiler sets up the indirect access pattern.
- Fix the nki_ir interpreter load/store to row semantics (was flat np.take/
  np.put), matching the hardware.

Frontend (nkipy):
- scatter_along_axis: a[:, t, :] = b -> move axis to row dim, scatter_rows.
- put_along_axis: per-element scatter via flatten-by-strides (HLO trick) onto
  width-1 row scatter; handles scalar values and axis=None.
- scatter_strided: a[::s, ::s] = b -> static cartesian-product flat indices +
  width-1 row scatter.
- take/take_along_axis axis-0 row-gather fast path via gather_rows, avoiding
  the full-table transpose that OOMed on tall tables (embedding 128256x2048).

Unblocks on Trainium: put_along_axis(_scalar_value), slice_assignment
(_indeterministic), step_slicing_assignment, rotary_embed, embedding_dynamo.

New HW tests: nkigen-lite/tests/tensor_ir/test_scatter.py (scatter_rows +
gather_rows, incl. N>128, M>128, duplicate indices).  Full tensor_ir suite
665 passed / 1 xfailed.

Out of scope (not scatter): test_view_assignment_semantics (pre-existing
view-aliasing bug, skipped on hlo) and llama_decoder_dynamo (OOM in the
128256-wide LM-head matmul, a matmul free-dim tiling limit).
…n limit

The same-last-dim reshape fast path expressed each multi-row P-tile as a
single source rectangle via flat_range_to_src_slices. When a tile's flat
range crossed a source leading-dim boundary (e.g. collapsing (3,100,8) into
(300,8)), the single-rectangle form silently truncated at the first
boundary, producing wrong data.

Add flat_range_to_src_chunks, which decomposes a contiguous flat range into
maximal rectangular sub-slices (one chunk for the aligned fast path, so no
extra DMAs there). Rewire both same-last-dim emitters (_lower_reshape_
same_last_dim and _emit_reshape_same_f) to emit one DMA per chunk, mapping
each chunk's whole rows 1:1 to consecutive output rows.
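
The chunk decomposition can be illustrated for the simplest case of two leading source dims. This is a sketch under assumptions, not the actual flat_range_to_src_chunks: the helper `flat_range_to_chunks_2d` and its tuple return format are hypothetical, and the DMA count is only a counter in this toy model.

```python
import numpy as np


def flat_range_to_chunks_2d(lead_shape, start, stop):
    """Split flat rows [start, stop) of an (A, B) leading shape into rectangles.

    Each chunk (a0, a1, b0, b1) stands for src[a0:a1, b0:b1, :].
    """
    A, B = lead_shape
    assert 0 <= start < stop <= A * B
    first, last = start // B, (stop - 1) // B
    if first == last:  # range stays inside one leading index: single rectangle
        return [(first, first + 1, start - first * B, stop - first * B)]
    chunks = []
    if start % B:  # partial head
        chunks.append((first, first + 1, start % B, B))
        first += 1
    full_end = stop // B  # leading indices [first, full_end) are fully covered
    if full_end > first:
        chunks.append((first, full_end, 0, B))
    if stop % B:  # partial tail
        chunks.append((full_end, full_end + 1, 0, stop % B))
    return chunks


def reshape_same_last_dim(src, tile_rows=128):
    """Collapse (A, B, F) -> (A*B, F), one copy per chunk of each P-tile."""
    A, B, F = src.shape
    out = np.empty((A * B, F), dtype=src.dtype)
    n_dma = 0
    for r0 in range(0, A * B, tile_rows):
        r1 = min(r0 + tile_rows, A * B)
        row = r0
        for a0, a1, b0, b1 in flat_range_to_chunks_2d((A, B), r0, r1):
            n = (a1 - a0) * (b1 - b0)
            out[row:row + n] = src[a0:a1, b0:b1].reshape(n, F)  # rows map 1:1
            row += n
            n_dma += 1
    return out, n_dma
```

For the (3,100,8) -> (300,8) case from the message, the first 128-row tile splits into rows 0..99 of leading index 0 plus rows 0..27 of leading index 1, which is exactly the boundary the single-rectangle form truncated at.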

This fixes take/gather for arbitrary flattened partition extents, so the
P <= 128 or P % 128 == 0 guard in _take_dynamic is removed. The two
test_slice_extraction xfails (P=300, P=525) now pass.
…structural pad in nkigen-lite

Address three documented NotImplementedError edge cases in the nkigen-lite
op layer that previously auto-xfailed:

- diff(prepend=, append=): concat the prepend/append operands (scalar or
  array, broadcast to the array's non-diff axes) onto the input before
  differencing, matching numpy.
- dynamic_update_slice with a non-uniform numpy array: route the value
  through the constant builder, which already run-length-encodes non-uniform
  data, instead of rejecting it.
- pad modes reflect / symmetric / wrap: build them from width-1 slabs in the
  numpy source-index order, alongside the existing constant/edge modes. The
  per-axis index patterns were verified against numpy for all in-range pad
  widths.
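
The per-axis index patterns for the three new pad modes can be sketched and checked against numpy as follows. The helper names are hypothetical and the sketch assumes the same in-range widths as the text: at most n-1 for reflect, at most n for symmetric and wrap.

```python
import numpy as np


def pad_source_indices(n, before, after, mode):
    """Source index for every output position along one axis of length n."""
    core = list(range(n))
    if mode == "reflect":        # mirror around the edge, edge not repeated
        left = [before - i for i in range(before)]
        right = [n - 2 - i for i in range(after)]
    elif mode == "symmetric":    # mirror around the edge, edge repeated
        left = [before - 1 - i for i in range(before)]
        right = [n - 1 - i for i in range(after)]
    elif mode == "wrap":         # periodic continuation
        left = [n - before + i for i in range(before)]
        right = list(range(after))
    else:
        raise ValueError(f"unsupported mode: {mode}")
    return left + core + right


def pad_ref(x, pad_width, mode):
    """Pad axis by axis by concatenating width-1 slabs in source-index order."""
    for axis, (before, after) in enumerate(pad_width):
        idx = pad_source_indices(x.shape[axis], before, after, mode)
        slabs = [np.take(x, [i], axis=axis) for i in idx]
        x = np.concatenate(slabs, axis=axis)
    return x
```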

All three are numerically validated on-device. The shared trace_mode tests
skip HLO for the pad modes and diff prepend/append, since HLO lacks both
(its pad supports only constant/edge, and its diff ignores prepend/append).
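
The diff(prepend=, append=) handling described above amounts to a concat followed by an adjacent-element subtraction. A minimal NumPy sketch, with hypothetical helper names; it broadcasts only scalar operands, which is what numpy itself does, and expects array operands to already match on the non-diff axes:

```python
import numpy as np


def _as_slab(v, a, axis):
    """Turn a scalar into a width-1 slab along `axis`; pass arrays through."""
    v = np.asarray(v)
    if v.ndim == 0:
        shape = list(a.shape)
        shape[axis] = 1
        v = np.broadcast_to(v, shape)
    return v


def diff_ref(a, axis=-1, prepend=None, append=None):
    a = np.asarray(a)
    parts = []
    if prepend is not None:
        parts.append(_as_slab(prepend, a, axis))
    parts.append(a)
    if append is not None:
        parts.append(_as_slab(append, a, axis))
    x = np.concatenate(parts, axis=axis)
    n = x.shape[axis]
    hi = np.take(x, np.arange(1, n), axis=axis)
    lo = np.take(x, np.arange(0, n - 1), axis=axis)
    return hi - lo  # first-order difference along `axis`
```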
…lement copies

_emit_strided_slice copied one element at a time whenever the free-dim stride
was non-unit, producing O(num_elements) DMAs. For conv2d/conv3d im2col (each
kernel position is a spatially-strided slice) this exploded the nki_ir graph:
a single (1,16,32,32)->(1,16,14,14) stride-2 slice lowered to 9,408 ops, and a
conv2d(16->32,k5,s2) lowered to 345k ops in ~7s, then took 1-2 min in
neuronx-cc.

The DMA engine already expresses strided access natively via per-dimension
DimSlice strides (honored by the interpreter and emit_to_kb's nb.coords affine
expressions). Tile the output P-dim like the contiguous slice path and emit one
strided load + contiguous store per tile.
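
The difference between the two lowering strategies can be modeled in NumPy on a 2-D (partition, free) array. This is a toy model with hypothetical function names and an assumed tile height of 128; the counters count copies in the model, not nki_ir ops, so they are not expected to reproduce the op counts quoted in this message.

```python
import numpy as np

PARTITION_MAX = 128  # assumed tile height


def strided_slice_per_element(src, start, stop, step):
    """Old behavior in spirit: one copy per output element."""
    cols = range(start, stop, step)
    out = np.empty((src.shape[0], len(cols)), dtype=src.dtype)
    n_copies = 0
    for r in range(src.shape[0]):
        for j, c in enumerate(cols):
            out[r, j] = src[r, c]
            n_copies += 1
    return out, n_copies


def strided_slice_tiled(src, start, stop, step):
    """New behavior in spirit: one strided load and one store per P-tile."""
    n_cols = len(range(start, stop, step))
    out = np.empty((src.shape[0], n_cols), dtype=src.dtype)
    n_copies = 0
    for p0 in range(0, src.shape[0], PARTITION_MAX):
        p1 = min(p0 + PARTITION_MAX, src.shape[0])
        tile = src[p0:p1, start:stop:step]  # stride expressed in the slice itself
        out[p0:p1] = tile                   # contiguous store
        n_copies += 2
    return out, n_copies
```

The copy count of the tiled version depends only on the number of P-tiles, not on the number of elements in the slice.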

Results: the hot slice drops 9,408 -> 48 ops; conv2d(16->32,k5,s2) drops
345k -> 33k ops and 6.9s -> 0.23s of Python lowering. End-to-end conv2d now
compiles in 6-18s (was 1-2 min); standard conv3d cases complete in seconds.

Unskip the conv2d and conv3d nkigen-lite tests accordingly (all pass on-device).
The 1152-channel Qwen3-VL conv3d case remains skipped: it hits a separate
reshape-lowering blowup (the (Co,*K,Ci)->(Co,K*Ci) weight reshape lowers to
millions of per-row DMAs), to be addressed next.
…d-trip)

The scratch-based reshape fallback reassembles each output row from fragments
when out_f > in_f, emitting O(out_f/in_f) tiny DMAs per row. For the conv
im2col weight reshape (Co,*K,Ci)->(Co, K*Ci) with Ci=3, K*Ci=1536 (the
1152-channel Qwen3-VL conv3d), that single reshape lowered to ~1.88M ops.

When in_shape and out_shape share a leading prefix-product P, the reshape only
regroups each row's free dimension (in_f -> out_f) with the partition axis
fixed. That is a zero-copy on-chip `view` — and crucially it is hardware-legal,
since KB requires SBUF views to preserve the partition dimension. Tile P at 128
and emit one contiguous load + view + contiguous store per tile.
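
The prefix-product fast path can be sketched in NumPy as below. The helper names are hypothetical, and picking the largest common prefix product is an assumption of this sketch (it also accepts P == 1); the commit message does not say how the lowering chooses P.

```python
import math
import numpy as np

PARTITION_MAX = 128  # assumed tile height


def common_prefix_product(in_shape, out_shape):
    """Largest P that is a product of a proper leading prefix of both shapes."""
    def prefix_products(shape):
        return {math.prod(shape[:i]) for i in range(1, len(shape))}
    common = prefix_products(in_shape) & prefix_products(out_shape)
    return max(common) if common else None


def _tail(shape, P):
    """Dims left over after the leading prefix whose product is P."""
    for i in range(1, len(shape)):
        if math.prod(shape[:i]) == P:
            return tuple(shape[i:])
    raise ValueError("P is not a prefix product of shape")


def reshape_via_view(x, out_shape):
    P = common_prefix_product(x.shape, out_shape)
    if P is None:
        raise NotImplementedError("no common prefix; the text uses the scratch fallback")
    src = x.reshape((P,) + _tail(x.shape, P))
    out_tail = _tail(out_shape, P)
    out = np.empty((P,) + out_tail, dtype=x.dtype)
    n_ops = 0
    for p0 in range(0, P, PARTITION_MAX):
        p1 = min(p0 + PARTITION_MAX, P)
        tile = src[p0:p1].copy()                   # contiguous load
        view = tile.reshape((p1 - p0,) + out_tail)  # regroup free dims only
        out[p0:p1] = view                           # contiguous store
        n_ops += 3
    return out.reshape(out_shape), n_ops
```

The partition extent of each tile is the same before and after the `reshape`, which is the property the text says the hardware requires of an SBUF view.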

The Qwen weight reshape drops 1.88M -> 36 ops; conv im2col window reshapes drop
to 4 ops each. Shapes with no usable common prefix (e.g. (256,1)->(1,256)) still
use the scratch fallback. Verified correct on-device (the IR reshape suite runs
on hardware, which rejected an earlier gcd-via-view attempt that changed the
partition dim — this path does not).

The Qwen conv3d case remains skipped: its dominant cost is now a separate
weight transpose (Ci from axis 1 to last), ~258k ops, to be addressed next.

Records the remaining Qwen3-VL conv3d bottleneck: the im2col weight transpose
(Co,Ci,*K)->(Co,*K,Ci) lowers to ~258k ops because the per-tile emitter only
swaps the last two dims, iterating everything else as per-element batch.

Captures the two prototyped ideas (axis collapse; fold leading passthrough dims
into the partition with one N-D dma_transpose) and the two hardware constraints
that block the clean version — dma_transpose's restricted perm set
(2D=[1,0],3D=[2,1,0],4D=[3,1,2,0]) and the need to decompose partial merged-axis
tiles back to original-rank rectangles — plus a suggested fix sequence and a
validation checklist noting the interpreter is not sufficient (it accepted both
hardware-rejected prototypes).
…en conv3d weight)

Merge runs of consecutive source dims that appear adjacent in the output
permutation into single axes before tiling. This reduces batch iterations
dramatically for cases like (Co, Ci, *K) -> (Co, *K, Ci) where spatial
dims form a contiguous run: Qwen3-VL weight transpose drops from 258k to
32k NKI ops. Uses flat_range_to_src_chunks to correctly decompose merged-
axis tiles that straddle original dim boundaries back to HBM rectangles.
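
The axis-merging step can be sketched as a pure shape/permutation rewrite. The function name `collapse_transpose` is hypothetical, and the sketch covers only the merge itself, not the tiling or the decomposition of merged-axis tiles back to HBM rectangles.

```python
import math
import numpy as np


def collapse_transpose(shape, perm):
    """Merge runs of consecutive source dims that stay adjacent in `perm`.

    Returns (merged_shape, merged_perm) such that
    x.reshape(merged_shape).transpose(merged_perm) holds the same data, in the
    same order, as x.transpose(perm).
    """
    groups = []  # runs of source dims, listed in output order
    for p in perm:
        if groups and p == groups[-1][-1] + 1:
            groups[-1].append(p)
        else:
            groups.append([p])
    src_order = sorted(groups, key=lambda g: g[0])  # same runs in source order
    merged_shape = tuple(math.prod(shape[d] for d in g) for g in src_order)
    merged_perm = tuple(src_order.index(g) for g in groups)
    return merged_shape, merged_perm


# (Co, Ci, k1, k2, k3) -> (Co, k1, k2, k3, Ci): the spatial dims form one run.
shape, perm = (4, 3, 2, 2, 2), (0, 2, 3, 4, 1)
merged_shape, merged_perm = collapse_transpose(shape, perm)
x = np.arange(math.prod(shape)).reshape(shape)
via_merged = x.reshape(merged_shape).transpose(merged_perm).reshape(4, 2, 2, 2, 3)
```

In this example the rank-5 transpose becomes a rank-3 one with the three spatial dims folded into a single axis of extent 8, so a per-tile emitter has fewer batch dims to iterate.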