WIP: add nkigen-lite as a standalone IR-based kernel generation backend#59
WIP: add nkigen-lite as a standalone IR-based kernel generation backend#59ymwangg wants to merge 45 commits into
Conversation
Migrates tensor_ir, nki_ir, and the direct lowering passes from nano-tensorizer/ir_lab into the nkipy workspace as a new package. The pipeline (canonicalize → decompose → layout_solver → direct_lower) produces legal NKI IR directly without intermediate passes.
Add nkigen-lite as a fully functional backend (backend="nkigen-lite") alongside hlo and nkigen. The pipeline traces Python kernels through nkigen_lite's tensor_ir Builder, lowers via the pass pipeline (canonicalize → decompose → layout_solver → direct_lower), and compiles to NEFF via the NKI kernel_builder API. nkipy integration: - backend/nkigen_lite.py: TraceContext, Tensor, IR adapter - ops/_nkigen_lite_impls.py: op implementations delegating to Builder - ops/_register_nkigen_lite.py: lazy op registration - trace.py: _specialize_nkigen_lite() dispatch - compile.py: _compile_nkigen_lite() via kernel_builder - knob.py, nki_op.py: backend-aware dispatch nkigen-lite enhancements: - Builder: add abs, sign, floor, ceil, power, floor_divide, mod ops - Interpreter: numpy dispatch for new ops, dtype-aware tensor_copy - Decompose pass: floor_divide/mod use divide-then-verify-and-correct strategy (matching neuronx-cc BIR), power→exp(b*log(a)), ceil→neg(floor(neg(x))), fixed-point iteration with max-iter guard - Direct lowering: abs/sign/sin via NisaActivationOp, floor via i32 truncation + sign correction, cast via tensor_copy, 1D reshape fix - docs/floor_divide_precision.md: documents the precision strategy Test results: 134/135 HLO-parity tests pass on trn2 hardware (99.3%).
Add "nkigen-lite" to the trace_mode fixture so all parametrized tests run with both backends. Add a pytest hook that marks NotImplementedError as xfail for nkigen-lite — ops not yet implemented show as expected failures and automatically start passing when added. Current results: - HLO: 741 passed, 4 xfailed, 42 skipped - nkigen-lite: ~340 passed, 161 xfailed (unimplemented ops), ~93 failed (partial implementations needing further work)
Add ReduceKeepdimsFalsePattern to decompose keepdims=False reductions into keepdims=True + reshape, which the layout solver and lowering require. Handle scalar (rank-0) tensors throughout the lowering pipeline by promoting them to (1,) at the NKI boundary since the hardware doesn't support rank-0 tensors. Also fix negative axis normalization in squeeze() and expand_dims().
- matmul: add 1D→2D promotion following NumPy semantics - squeeze: validate non-1 dims, normalize negative axis - reshape: handle int newshape argument - zeros/full: handle int shape argument - concatenate: handle single-tensor case, validate empty/axis bounds - split: validate axis bounds and unequal division - where: handle numpy array condition argument - _ensure_value: handle numpy array operands (uniform-fill)
- expand_dims: validate duplicate axes and out-of-bounds axis - Skip test_reduce_unsupported_op and test_topk_non_last_axis for non-HLO backends since they test HLO-specific internal behavior
NeuronCore hardware only supports Add and Max for cross_lane_reduce_arith. Implement MIN as -max(-x) transparently in the NKI IR builder so all existing P-dimension reduce codepaths work with min reductions.
Replace HLO-specific DeviceKernel.compile_and_load path with the shared on_device_test utility which handles input/output naming differences between backends automatically.
- broadcast_to: handle scalar (rank-0) source by loading the single element and broadcasting via tensor_scalar_arith with ones - emit_to_kb: auto-cast f16/bf16 operands to f32 around tensor_scalar_arith since the hardware scalar engine requires f32
Add NisaBitvecOp enum and tensor_tensor_bitvec builder method to NKI IR. Wire through the full pipeline: tensor IR opcodes, elementwise lowering, emit_to_kb mapping, and interpreter support. Replace the old arithmetic approximations (which only worked for booleans) with hardware bitwise instructions that work correctly on integer types.
Add NKI IR primitives for mixed-dtype operations: - tensor_tensor_compare: comparison ops (IsGT, IsGE, etc.) that accept float inputs and produce uint8 predicate output - tensor_scalar_bitvec: scalar bitvec ops (XOR for logical NOT, etc.) - Comparison and logical op variants in NisaArithOp enum Rewrite _emit_floor to use the NKI compiler's compare+select pattern: trunc→compare→conditional select in integer domain, avoiding float precision issues in the correction step.
…lice - emit_slice: add strides parameter; delegate to _emit_strided_slice for non-unit strides (element-by-element DMA for F-stride, row-by-row for P-stride) - dynamic_update_slice: handle numpy array value argument (uniform fill)
Add DType.FP8_E4M3_IEEE for the IEEE-standard float8_e4m3 format (distinct from the NaN-free float8_e4m3fn variant already supported). Wire through core, emit_to_kb, and compile dtype mappings.
Each pytest-xdist worker now claims a specific Neuron core via NEURON_RT_VISIBLE_CORES, enabling parallel test execution across all 64 available cores (~8.5x speedup).
- Comparison ops (equal, not_equal, greater, less, etc.) now produce same dtype as input (1.0/0.0 float) matching NKI convention, instead of DType.BOOL - where op lowered using NKI pattern: cond*x + (1-cond)*y with all float arithmetic — no mixed-dtype operations needed - Map DType.BOOL → uint8 in kernel builder and execution layer - Update tensor IR builder to remove BOOL requirement from where - Reduces xfail count from 162 → 125 (37 tests now passing)
Matches NKI compiler's approach: cos(x) = sin(x + π/2). The hardware sin activation instruction handles the computation.
Implement np.dot semantics as a composed op: - 1D/2D cases delegate directly to matmul - N-D × 1D delegates to matmul (batched matrix-vector) - N-D × M-D decomposes to reshape + matmul + reshape to achieve the outer-product batch semantics of np.dot
- arctan: wire native NISA ARCTAN activation through the Builder and direct-lower tables - invert/bitwise_not: composed_impl as XOR with all-ones (-1), matching the NKI compiler's implementation - logical_and/or/xor: composed_impl via 0/1 truthiness; also unblocks rint/round which depend on logical_and - constant: backend impl mirroring HLO (passthrough + uniform fill); non-uniform array constants raise NotImplementedError Fix a pre-existing bug in _emit_broadcast_scalar that fed a (1,1) tile to tensor_scalar_arith whose scalar operand partition dim must match the destination; replicate to (p_size, 1) via broadcast_partition. Also make test_ml_dtypes_constant_encoding's float8 xfails backend-aware so float8_e5m2 on nkigen-lite no longer reports XPASS.
The slice-based gather produced wrong output shapes: it ignored axis=None (no flatten), concatenated slices along the original axis instead of replacing it with indices.shape, and mishandled scalar and multi-dimensional index arrays. Rewrite to match numpy: out.shape == a.shape[:axis] + indices.shape + a.shape[axis+1:] - axis=None flattens the input first - negative indices are normalized modulo the axis dimension - each flat index becomes a width-1 slice; slices are concatenated then reshaped so the gathered axis is replaced by indices.shape (dropped entirely for a scalar index) Fixes 13 failing test_take_scalar / test_take_numpy_indices cases.
…) for nkigen-lite Wires the four distributed collectives through the full nkigen-lite stack: - tensor_ir Builder: collective ops with correct output-shape inference (all_gather grows the gather dim, reduce_scatter/all_to_all shrink/grow by world size) - nkipy lite impls + registration, mapping numpy reduce ufuncs to the collective reduce-op names - direct_lower: stage collectives through internal HBM scratch buffers (the compiler forbids collectives from reading/writing kernel IO tensors directly — "Collective instruction cannot read IO tensors") - nki_ir Builder: collective() side-effect node (HBM->HBM) - emit_to_kb: lower to nisa.all_reduce/all_gather/reduce_scatter/all_to_all via ExplicitReplicaGroupAttr + dma_compute_reduce_op The KB collective API only operates on the last (free) axis of 2D HBM tensors (cc_dim=0 raises std::bad_cast), so all_gather/reduce_scatter along other axes are staged via transpose-collective-transpose. Fixes the all_reduce/all_gather/reduce_scatter/all_to_all xfails (multiply-reduce variants stay xfailed for both backends — unsupported by the compiler).
The earlier transpose workaround for all_gather/reduce_scatter was based on a misdiagnosis: cc_dim=0 appeared to raise std::bad_cast, so collectives were staged through a transpose to operate on the last axis. Multi-core numerical verification showed that path silently dropped the remote rank's data (all_gather duplicated the local source; reduce_scatter ignored the per-rank scatter offset). Root cause: the KB nisa collective APIs forward cc_dim to the native builder un-converted, so a bare int 0 fails the int->enum cast. The NKI collectives contract also requires collective_dim=0 for HBM tensors. Fix: - emit_to_kb: convert the int dim to CollectiveDimension (DIM_0/DIM_1) before calling nisa.all_gather/reduce_scatter/all_to_all - drop the transpose workaround; gather/scatter along the requested dim directly Verified on 2 NeuronCores with distinct per-rank data: all_reduce, all_gather(dim0), reduce_scatter(dim0), and all_to_all all produce the correct cross-rank results.
Two bugs combined to make a // b and a % b off by one on the rare inputs where the true quotient is an exact integer: 1. The composed floor_divide impl (floor(divide(x, y))) ran at trace time, so the graph never contained a `floor_divide` opcode and the decompose pass's divide-then-verify-and-correct FloorDividePattern never fired. NeuronCore has no native divide -- it uses reciprocal multiply, which undershoots exact integers (2.0 -> 1.9999999), so plain floor gave N-1. Fix: register nkigen-lite-specific floor_divide/remainder impls that emit the native floor_divide/mod opcodes, so the correcting pattern runs. 2. Within FloorDividePattern, the up-correction used max(0, sign(|rem| - |b|)), which is 0 at the |rem| == |b| boundary (sign(0) == 0) -- exactly the exact-integer undershoot case. Replace with an inclusive greater_equal(|rem|, |b|); a genuine remainder is always strictly < |b|, so equality can only mean undershoot. Fixes the 3 failing floordiv/mod broadcasting cases. Exact-integer-boundary inputs remain inherently ambiguous under reciprocal division (numpy and the device can disagree by an ULP), but all test cases now pass.
Status Update: Replacing the HLO backend with nkigen-liteReported 2026-06-23 from full test suite run ( Where we areOver Jun 2–5, nkigen-lite was stood up as a third backend and brought to broad Current suite status: To finish retiring the
1. Hard failures — numerical bugs (2 cases) — fix firstThese already run but produce wrong results:
2. Unimplemented ops (58 cases)Each op unblocks its cluster of tests:
3. Capability gaps — broader subsystems (42 cases)
4. Out of scope — compiler limitation (2 cases — leave xfailed)
These cannot be cleared by nkigen-lite work — hardware/compiler limitation. Remaining work, in priority order
Clearing §1–§3 retires ~100 of the 102 hlo-only-passing cases; the 2 |
NRT does not carry FP8 dtypes through compiled neff metadata: e4m3/e5m2
surface as "int8" and e4m3fn surfaces as "unknown". The data round-trips
correctly on device, so the only blocker was spike's dtype validation,
which special-cased only e4m3/e5m2 and only accepted "int8".
Extend _check_dtype_compatibility to cover e4m3fn and accept either
placeholder ("int8"/"unknown"), gated behind the FP8 dtype set so
non-FP8 tensors are still validated strictly. Drop the matching
nkigen-lite xfail in test_ml_dtypes_constant_encoding.
The high-level tensor_ir.Builder had no index-ramp op, while the low-level nki_ir already exposed nisa.iota. Bridge that gap across all four layers: - tensor_ir/ir.py: Builder.iota(shape, dim, dtype) — out[..., i, ...] == i along dim, broadcast over other axes (np.arange-on-axis semantics). - core.py: numpy eval in eval_common_op (shared by both interpreters). - direct_lower.py: _emit_iota_op tiles under the canonical row-major layout and maps each axis to nisa.iota's pattern/channel_multiplier/offset (free -> step 1; partition -> channel_multiplier 1 + p_off; batch -> constant offset). Kept out of ELEMENTWISE_OPCODES since it is position-dependent. Unblocks tril/triu/diag/trace, which build index masks via iota. Adds TestIota HW coverage (per-axis, multi-tile partition, rank-3, and iota feeding an elementwise op).
Implement gap-8 triangular/diagonal ops on top of the new iota primitive, mirroring the HLO impls: build row/col index masks via iota + compare, then where(mask, x, 0). - tril/triu: keep row >= col-k (lower) or row <= col-k (upper); masks built over the last two axes, broadcast across batch dims. - diag 1D->2D: extend v to length N with a zero-pad on the side away from the diagonal, broadcast across columns, keep col == row+k. Avoids the HLO take-based gather (dynamic indexing is unsupported on nkigen-lite). - diag 2D->1D: mask the k-th diagonal and reduce-sum the off-axis to collapse to the diagonal vector; slice to diag_len. - trace: mask the diagonal (offset) and reduce-sum both axes. Flips 8 xfails to pass in tests/unit/test_tensor_api.py. Verified against numpy on non-square diag/tril/triu and trace-with-offset edge cases beyond the existing suite.
Implement gap-7 (pad) and gap-9 (flip/roll/tile/diff) as pure slice/concat data movement — no new primitives needed. - flip: reverse an axis by concatenating width-1 slices in descending order. - tile: concatenate r copies of the running result along each axis. - roll: cyclic shift via split at (n-shift) + swapped concat; supports axis=None (flatten), int axis, and tuple shift/axis. - diff: iterated x[1:] - x[:-1] along axis (n times). - pad: constant mode concatenates full-valued slabs; edge mode replicates the first/last slab. Handles scalar, per-axis, and asymmetric pad_width. Flips 14 xfails to pass in tests/unit/test_tensor_api.py. Verified against numpy on 3-D flip, multi-axis flip/roll, flattened roll, 3-D tile, and 3-D asymmetric pad beyond the existing suite.
Implement argmax/argmin via index masking on top of iota, mirroring the HLO algorithm: reduce to the extreme value along the axis, mark positions equal to it with their index (an iota ramp) and all others with a large sentinel, then min-reduce the indices — returning the first index that attains the extreme, matching numpy's tie-break. The whole computation runs in float32 (cast input up front, cast result to int32 at the end): min/max reductions initialize with +/-inf, which cannot be memset into an integer tile, so an integer input or index ramp would fail to compile. Handles axis=None (flatten), negative axis, and keepdims. Flips 8 xfails to pass; verified against numpy on int inputs and tie-breaking beyond the suite.
The lite builder only emits uniform fills, so non-uniform array constants previously raised NotImplementedError. Materialize them as a flat sequence of run-length fills, concatenate, and reshape — cheap for structured/small arrays, capped at 4096 runs to keep tracing bounded. Route the binary-operand path (_ensure_value) through the same logic. cumsum gets a dedicated nkigen-lite impl rather than relying on the composed fallback (which builds constant(np.triu(np.ones((N,N))))): for axis=None the flattened triangular matrix is (4096,4096), far too large as a literal. Build U[i,j] = (i<=j) via iota + compare instead, then cumsum = x_2d @ U. Handles axis=None/negative/middle axis and dtype. Flips 8 xfails to pass (4 cumsum, list/tuple constant, integer where-cond). Verified against numpy on 1-D/3-D/negative-axis cumsum and small structured constants beyond the suite.
nkigen-lite has no convolution primitive, so decompose N-D conv into im2col + a single matmul: gather each kernel position's strided window as a (N, Ci, out_pts) block, concat along the channel axis into (N, Ci*prod(K), out_pts), flatten the (transposed) weight to (Co, Ci*prod(K)), and do one batched matmul -> (N, Co, out_pts). A single fused matmul compiles ~35% faster than accumulating prod(K) separate matmuls (95.6s -> 61.5s on ic=16/oc=32/k=3); the official conv2d suite drops from 165s to 121s. Spatial padding is built from concat of zero slabs; strided/dilated windows use strided slice. groups != 1 is unsupported. Verified against PyTorch on conv2d (stride/padding/dilation/bias/1x1) and conv3d including a non-cubic (2,3,3) kernel. conv tests require torch as the oracle, now installed via the examples dependency group.
repeat: insert a size-1 axis after the repeat axis, broadcast it to the repeat count, then reshape to fold it back in — duplicating each element in place (matching np.repeat / the HLO impl). Scalar integer repeats only. split: handle an explicit list of split indices (numpy semantics — boundaries clamped to the axis size, len(indices)+1 sub-arrays). Empty sub-arrays (repeated/out-of-range index) raise NotImplementedError, since the lite IR has no zero-size tensor representation. Flips 10 xfails to pass (9 repeat variants + split_indices). Verified repeat on axis=None/middle-axis beyond the suite.
The hardware SIN activation is only accurate near [-π, π]; outside it the polynomial approximation diverges wildly (cos(500) returned ~2e7 instead of a value in [-1, 1]). This silently corrupted any kernel feeding large arguments into sin/cos — e.g. rope_dynamo, where bmm produces values up to ~512, showed a ~52% output mismatch. Add SinRangeReductionPattern to the decompose pass: sin(x) → sin(x - 2π·round( x/2π)), with round(y)=floor(y+0.5), bringing the argument into [-π, π] before the hardware SIN. The emitted inner sin carries a range_reduced attr so the pattern doesn't re-match and loop. Runs after CosPattern so cos→sin lowering happens first and both get reduced. Fixes the rope_dynamo integration test (now passing) and makes cos/sin accurate to <1e-4 across ±500 (was max error ~2e7). One pre-existing bug remains: view-assignment aliasing.
No sort primitive exists, so extract top-k iteratively: k times, take the max along the axis, record value + argmax index, then mask that single position to -inf so the next round finds the following element. Masking by index position (via an iota ramp) rather than value gives a stable lowest-index pick among equal maxima. Supports descending (default) and ascending (negate in/out); returns (values, uint32 indices) matching torch.topk. Flips 8 xfails to pass, verified against torch.topk. Note: for exact-value ties the index ordering may differ from torch (both are valid top-k results); random-data tests are unaffected.
Wire the max8 / find_index8 selection primitives through every IR layer (nki_ir Builder + interpreter + emit_to_kb, and the tensor_ir Builder + interpreter + direct_lower), mirroring the iota plumbing. max8 returns the 8 largest values per partition along the free dim. topk now computes values with hardware max8 (one instruction) instead of k iterative reduce-max passes: move the topk axis to the free dim, flatten leading dims to the partition dim, pad the free dim to >=8 with -inf, max8, slice the first k. Limited to k<=8. Indices are still recovered via the iota position-ramp argmin trick, NOT find_index8: find_index8 (the MaxIndex instruction) fails the neuronx-cc ISA check [NCC_IXCG864] on this target, so it is wired through the IR (and works in the interpreter) but not used in the topk lowering. All 8 topk tests pass against torch; verified 1-D, ascending, and index tie-break edge cases.
Replace the k<=8 max8-only topk with the scanning loop used by nkilib's
topk_core, which supports any k and recovers indices from hardware.
Wire nc_match_replace8 through the IR (nki_ir Builder two-result op +
interpreter + emit_to_kb via nisa.max_index_and_match_replace). Model topk as
a single tensor_ir op (2 results: values + uint32 indices) whose lowering runs
the SBUF-resident scan: ceil(k/8) folds of max8 (next 8 largest) +
match_replace8 (record indices, mask taken values to -inf). Free dim < 8 is
padded with -inf; each fold's results DMA-store to the matching output column
slice.
Key fixes found by incremental testing:
- find_index8 is gen2-only and fails the ISA check on trn2; match_replace8
(dst_idx) is the gen3+ index path.
- match_replace8 indices must be uint32, not int32 (DVE_READ_INDICES AP
constraint NCC_IXCG988).
- insert_deallocs special-cases match_replace8's two dst buffers: result[1]
aliases input[1] (dst_idx), not input[0] — the generic rule mis-freed the
index tile ("use of released tile").
All 8 topk tests pass; verified k>8 (k=10,16 multi-fold), F<8 padding, 1-D,
and ascending against torch including indices.
…lash) The nkigen-lite/tests HW suite failed entirely under `-n auto`. Root cause was NOT plain core contention: spike._spike and nki.runtime._spike are separate compiled extension modules that collide in CPython's loader — whichever is imported second resolves to the first, raising "ImportError: cannot import name 'ModelTensorInfo'". The nkigen-lite HW tests execute via nki.runtime, so any prior `import spike` in a worker broke them. This reproduced even at -n 1 (one worker) while -n 0 passed. Populate the empty nkigen-lite/tests/conftest.py to: - count NeuronCores by enumerating /dev/neuron* instead of importing spike, so the runtime module clash never occurs in a worker; - pin each xdist worker to its own core (NEURON_RT_VISIBLE_CORES) and cap -n auto worker count to the core count, mirroring tests/conftest.py; - register the `hw` marker (removes PytestUnknownMarkWarning; -m "not hw" now selects correctly). Full nkigen-lite suite: 685 passed, 1 xfailed under -n auto (was 52+ HW failures). Note: run serial suites with -n0, not -p no:xdist, so the pytest_xdist_auto_num_workers hook stays valid.
conv3d lowering hangs the suite on large-channel cases (the Qwen3-VL 3->1152 case never completes), so xfail won't help (the body still runs). conv2d passes but is very slow (~1-2 min/case). Skip both on nkigen-lite to keep the suite fast and non-hanging; HLO coverage is unaffected. Remove once the conv lowering path is optimized.
Add a 2-D per-partition runtime gather primitive (gather_along_axis) to tensor_ir, lowering to the hardware nisa.gather instruction with partition tiling for P > 128. Wire take_along_axis through it in the nkipy frontend: broadcast indices on non-axis dims (matching HLO), move the gather axis to the free dim, flatten leading dims to the partition dim, gather, then reshape/transpose back. Handles axis=None and negative axes. Unblocks take_along_axis and diagonal_gather (which routes through take_along_axis) on nkigen-lite; verified on Trainium hardware. New tests: nkigen-lite/tests/tensor_ir/test_gather.py (interpreter, lowering gate, and hardware across 6 shapes incl. P>128 tiling).
Route np.take with runtime indices through the gather_along_axis hardware primitive instead of raising NotImplementedError: move the gather axis to the free dim, flatten leading dims to the partition dim, broadcast the shared index vector across partitions, gather, then reshape/transpose the (P, M) result back to numpy's take layout (axis replaced by indices.shape). Guards for cases that can't lower: - non-integer indices (boolean masks; nkigen-lite reports comparisons as f32 so the frontend bool guard misses them) -> NotImplementedError, which preserves the boolean-indexing negative tests. - flattened partition extent P that is >128 and not a multiple of 128, a pre-existing reshape-lowering limit -> NotImplementedError (clean xfail). Also skip the no-op inverse transpose for scalar/degenerate indices, which otherwise tripped the rank-1 transpose lowering. Unblocks dynamic take (test_take), the view-as-index MoE patterns, and the gather half of rotary_embed; verified on Trainium.
Add the scatter-family ops by fixing and building on the indirect-DMA primitive (dma_copy_indirect), which was previously unused and miswired. nki_ir primitives (tensor_ir): - scatter_rows: out = base.copy(); out[idx[r],:] = updates[r,:] -> indirect DMA store, via the canonical dst.ap(vector_offset=) view + dma_copy. - gather_rows: out[r,:] = src[idx[r],:] -> indirect DMA load, same idiom on the source. Both tile over PARTITION_MAX and take (M,1) U32 indices. - Fix emit_to_kb store/load: the old low-level dma_copy_indirect calls passed mismatched full tiles (src.size != dst.size); use the .ap(vector_offset=) view so the DMA tiler sets up the indirect access pattern. - Fix the nki_ir interpreter load/store to row semantics (was flat np.take/ np.put), matching the hardware. Frontend (nkipy): - scatter_along_axis: a[:, t, :] = b -> move axis to row dim, scatter_rows. - put_along_axis: per-element scatter via flatten-by-strides (HLO trick) onto width-1 row scatter; handles scalar values and axis=None. - scatter_strided: a[::s, ::s] = b -> static cartesian-product flat indices + width-1 row scatter. - take/take_along_axis axis-0 row-gather fast path via gather_rows, avoiding the full-table transpose that OOMed on tall tables (embedding 128256x2048). Unblocks on Trainium: put_along_axis(_scalar_value), slice_assignment (_indeterministic), step_slicing_assignment, rotary_embed, embedding_dynamo. New HW tests: nkigen-lite/tests/tensor_ir/test_scatter.py (scatter_rows + gather_rows, incl. N>128, M>128, duplicate indices). Full tensor_ir suite 665 passed / 1 xfailed. Out of scope (not scatter): test_view_assignment_semantics (pre-existing view-aliasing bug, skipped on hlo) and llama_decoder_dynamo (OOM in the 128256-wide LM-head matmul, a matmul free-dim tiling limit).
…n limit The same-last-dim reshape fast path expressed each multi-row P-tile as a single source rectangle via flat_range_to_src_slices. When a tile's flat range crossed a source leading-dim boundary (e.g. collapsing (3,100,8) into (300,8)), the single-rectangle form silently truncated at the first boundary, producing wrong data. Add flat_range_to_src_chunks, which decomposes a contiguous flat range into maximal rectangular sub-slices (one chunk for the aligned fast path, so no extra DMAs there). Rewire both same-last-dim emitters (_lower_reshape_ same_last_dim and _emit_reshape_same_f) to emit one DMA per chunk, mapping each chunk's whole rows 1:1 to consecutive output rows. This fixes take/gather for arbitrary flattened partition extents, so the P <= 128 or P % 128 == 0 guard in _take_dynamic is removed. The two test_slice_extraction xfails (P=300, P=525) now pass.
…structural pad in nkigen-lite Address three documented NotImplementedError edge cases in the nkigen-lite op layer that previously auto-xfailed: - diff(prepend=, append=): concat the prepend/append operands (scalar or array, broadcast to the array's non-diff axes) onto the input before differencing, matching numpy. - dynamic_update_slice with a non-uniform numpy array: route the value through the constant builder, which already run-length-encodes non-uniform data, instead of rejecting it. - pad modes reflect / symmetric / wrap: build them from width-1 slabs in the numpy source-index order, alongside the existing constant/edge modes. The per-axis index patterns were verified against numpy for all in-range pad widths. All three are numerically validated on-device. The shared trace_mode tests skip HLO for the pad modes and diff prepend/append, since HLO lacks both (its pad supports only constant/edge, and its diff ignores prepend/append).
…lement copies _emit_strided_slice copied one element at a time whenever the free-dim stride was non-unit, producing O(num_elements) DMAs. For conv2d/conv3d im2col (each kernel position is a spatially-strided slice) this exploded the nki_ir graph: a single (1,16,32,32)->(1,16,14,14) stride-2 slice lowered to 9,408 ops, and a conv2d(16->32,k5,s2) lowered to 345k ops in ~7s, then took 1-2 min in neuronx-cc. The DMA engine already expresses strided access natively via per-dimension DimSlice strides (honored by the interpreter and emit_to_kb's nb.coords affine expressions). Tile the output P-dim like the contiguous slice path and emit one strided load + contiguous store per tile. Results: the hot slice drops 9,408 -> 48 ops; conv2d(16->32,k5,s2) drops 345k -> 33k ops and 6.9s -> 0.23s of Python lowering. End-to-end conv2d now compiles in 6-18s (was 1-2 min); standard conv3d cases complete in seconds. Unskip the conv2d and conv3d nkigen-lite tests accordingly (all pass on-device). The 1152-channel Qwen3-VL conv3d case remains skipped: it hits a separate reshape-lowering blowup (the (Co,*K,Ci)->(Co,K*Ci) weight reshape lowers to millions of per-row DMAs), to be addressed next.
…d-trip) The scratch-based reshape fallback reassembles each output row from fragments when out_f > in_f, emitting O(out_f/in_f) tiny DMAs per row. For the conv im2col weight reshape (Co,*K,Ci)->(Co, K*Ci) with Ci=3, K*Ci=1536 (the 1152-channel Qwen3-VL conv3d), that single reshape lowered to ~1.88M ops. When in_shape and out_shape share a leading prefix-product P, the reshape only regroups each row's free dimension (in_f -> out_f) with the partition axis fixed. That is a zero-copy on-chip `view` — and crucially it is hardware-legal, since KB requires SBUF views to preserve the partition dimension. Tile P at 128 and emit one contiguous load + view + contiguous store per tile. The Qwen weight reshape drops 1.88M -> 36 ops; conv im2col window reshapes drop to 4 ops each. Shapes with no usable common prefix (e.g. (256,1)->(1,256)) still use the scratch fallback. Verified correct on-device (the IR reshape suite runs on hardware, which rejected an earlier gcd-via-view attempt that changed the partition dim — this path does not). The Qwen conv3d case remains skipped: its dominant cost is now a separate weight transpose (Ci from axis 1 to last), ~258k ops, to be addressed next.
Records the remaining Qwen3-VL conv3d bottleneck: the im2col weight transpose (Co,Ci,*K)->(Co,*K,Ci) lowers to ~258k ops because the per-tile emitter only swaps the last two dims, iterating everything else as per-element batch. Captures the two prototyped ideas (axis collapse; fold leading passthrough dims into the partition with one N-D dma_transpose) and the two hardware constraints that block the clean version — dma_transpose's restricted perm set (2D=[1,0],3D=[2,1,0],4D=[3,1,2,0]) and the need to decompose partial merged-axis tiles back to original-rank rectangles — plus a suggested fix sequence and a validation checklist noting the interpreter is not sufficient (it accepted both hardware-rejected prototypes).
…en conv3d weight) Merge runs of consecutive source dims that appear adjacent in the output permutation into single axes before tiling. This reduces batch iterations dramatically for cases like (Co, Ci, *K) -> (Co, *K, Ci) where spatial dims form a contiguous run: Qwen3-VL weight transpose drops from 258k to 32k NKI ops. Uses flat_range_to_src_chunks to correctly decompose merged- axis tiles that straddle original dim boundaries back to HBM rectangles.
Summary
Adds
nkigen-lite, a standalone IR-based kernel generation backend that lowers numpy-style tensor programs to NKI (Neuron Kernel Interface) code for NeuronCore targets.Architecture
The system is structured as a three-layer IR stack with a multi-phase lowering pipeline:
Core (
core.py)Shared SSA-based IR infrastructure used by both IRs:
Value,Op,Graph— SSA primitives with use-lists and mutation helpersDTypeenum covering f32/f16/bf16/tf32/fp8/int typesTensor IR (
tensor_ir/)High-level, hardware-agnostic IR operating on whole tensors:
NKI IR (
nki_ir/)Low-level IR that makes hardware concerns explicit:
Lowering Pipeline (
tensor_ir/passes/)The full pipeline:
tensor_ir → canonicalize → decompose → layout_solver → direct_lower → nki_irCanonicalize — recomposes primitive-op chains into high-level ops (e.g.,
div(1, sqrt(x))→rsqrt(x),mul(x, div(1, add(1, exp(neg(x)))))→silu(x))Decompose — lowers ops without direct NISA equivalents into supported primitives (e.g.,
div(a,b)→mul(a, reciprocal(b)),reduce(mean)→reduce(sum) * 1/N)Layout Solver — assigns each tensor dimension to one of three roles:
Propagates constraints across the graph to find a globally consistent assignment.
Direct Lower — converts tensor IR ops to tiled NKI IR:
Hardware Target (
passes/hardware.py)Parameterized hardware profiles (TRN2 defaults) defining partition limits, SBUF/PSUM sizes, and matmul constraints.
Status
🚧 Work in progress — not ready for review.
Test plan
uv run pytest nkigen-lite/tests/ -n auto)