From b0f5cb453df22dd2568af902086a74639dc3635a Mon Sep 17 00:00:00 2001 From: xniwangaws Date: Wed, 3 Jun 2026 06:48:09 +0000 Subject: [PATCH 1/2] contrib: Gemma-4-26B-A4B-it port (MoE, TP=8, BF16) Port of google/gemma-4-26B-A4B-it (~25.2B/3.8B active MoE, 30 layers, 8 active / 128 total + 1 shared expert via parallel dense MLP). Validated on trn2.48xlarge (SDK 2.29): Stage 1 (DISABLE_MOE): compile 2.2 min, load 20.9s, NEFF 17 MB Stage 2 (MoE on): compile 19.7 min, load 29.1s, NEFF 297 MB Stage 3 (inference): TTFT 309.5 ms, TPOT 8.79 ms, 114 tok/s Borrows NKI flash attn (d=256 SWA + d=512), KV cache manager, softcap LM head, scaled embedding, and Gemma4 RMSNorm flavours from PR #106 (gemma-4-31B-IT). 26B-A4B-specific: MoE block with parallel dense MLP at decoder layer (HF source lines 1429-1441), dual-input MoE forward (router sees raw residual, experts see post_feedforward_layernorm_2(residual)), 128 experts with top-k=8 + per-expert-scale routing. --- .../gemma-4-26b-a4b-it/DIFF_FROM_PR106.md | 104 ++ contrib/models/gemma-4-26b-a4b-it/README.md | 176 +++ .../scripts/smoke_compile.py | 139 ++ .../scripts/smoke_inference.py | 228 +++ .../models/gemma-4-26b-a4b-it/src/__init__.py | 18 + .../src/configuration_gemma4_neuron.py | 111 ++ .../src/modeling_gemma4_neuron.py | 1356 +++++++++++++++++ .../gemma-4-26b-a4b-it/src/ndxi_patch.py | 505 ++++++ .../src/nki_flash_attn_d256_swa.py | 950 ++++++++++++ .../src/nki_flash_attn_large_d.py | 346 +++++ .../gemma-4-26b-a4b-it/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 127 ++ .../gemma-4-26b-a4b-it/test/unit/__init__.py | 0 14 files changed, 4060 insertions(+) create mode 100644 contrib/models/gemma-4-26b-a4b-it/DIFF_FROM_PR106.md create mode 100644 contrib/models/gemma-4-26b-a4b-it/README.md create mode 100644 contrib/models/gemma-4-26b-a4b-it/scripts/smoke_compile.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/scripts/smoke_inference.py create mode 100644 
contrib/models/gemma-4-26b-a4b-it/src/__init__.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/src/configuration_gemma4_neuron.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/src/modeling_gemma4_neuron.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/src/ndxi_patch.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_d256_swa.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_large_d.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/test/__init__.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/test/integration/__init__.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/test/integration/test_model.py create mode 100644 contrib/models/gemma-4-26b-a4b-it/test/unit/__init__.py diff --git a/contrib/models/gemma-4-26b-a4b-it/DIFF_FROM_PR106.md b/contrib/models/gemma-4-26b-a4b-it/DIFF_FROM_PR106.md new file mode 100644 index 00000000..43996c7b --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/DIFF_FROM_PR106.md @@ -0,0 +1,104 @@ +# Diff vs PR #106 (gemma-4-31B-IT) + +This port shares the Gemma 4 attention / norm / softcap / RoPE machinery +with [PR #106](https://github.com/aws-neuron/neuronx-distributed-inference/pull/106) +(Jim Burtoft, gemma-4-31B-IT). The intent of this diff is to make review +easy by listing exactly what is **identical**, what is **adapted**, and +what is **new for the 26B-A4B MoE variant**. 
+ +## Summary + +| Category | File | Status | +|---|---|---| +| NKI sliding-window flash attention (head_dim=256) | `src/nki_flash_attn_d256_swa.py` | **Identical** to PR #106 | +| NKI flash attention for head_dim>128 | `src/nki_flash_attn_large_d.py` | **Identical** to PR #106 | +| NxDI runtime patches | `src/ndxi_patch.py` | **PR #106 + 1-line relative-import fix** | +| Modeling | `src/modeling_gemma4_neuron.py` | **Adapted** (text-only; adds MoE block + router) | +| Configuration shim | `src/configuration_gemma4_neuron.py` | New (was inline in PR #106) | +| Vision / VLM | – | **Not ported** (text-only) | + +## File-by-file + +### `src/nki_flash_attn_d256_swa.py`, `src/nki_flash_attn_large_d.py` + +Verbatim copies of PR #106 kernels. Head dimensions on the 26B-A4B variant +match the 31B-IT (SWA layers head_dim=256, global head_dim=512, GQA 2:1) +so no kernel changes are required. + +### `src/ndxi_patch.py` + +Imports the NKI flash-attention kernel through a relative import so the +patch module is self-contained inside this port directory: + +```python +# Prefer relative import when this module ships inside the src/ package. +from .nki_flash_attn_large_d import flash_attn_large_d +``` + +Behaviour is otherwise unchanged from PR #106. + +### `src/modeling_gemma4_neuron.py` + +**Reused 1:1 from PR #106 (renamed only):** + +- `Gemma4RMSNorm`, `Gemma4VNorm` — RMSNorm flavours. +- `Gemma4ScaledEmbedding` — `embed * sqrt(hidden_size)`. +- `SoftcappedLMHead` — `cap * tanh(x / cap)` with `cap=30.0` in fp32. +- `Gemma4KVCacheManager` — per-layer heterogeneous KV shapes. +- `NeuronGemma4Attention` — partial RoPE for global, K=V at weight level, + NKI d=256 SWA prefill, post-projection v_norm. +- Q-norm pre-scaling trick in the state-dict converter (cancels NxDI's + automatic `1/sqrt(head_dim)`). + +**26B-A4B-specific additions:** + +- `NeuronGemma4Router` — FP32 softmax + top-k + renormalise + per-expert + learned scale. Reads `scale` and `per_expert_scale` learned tensors. 
+- `NeuronGemma4MoEBlock` — thin wrapper around NxDI `initialize_moe_module` + that consumes the gemma4 router's `top_k_index` / `top_k_weights`. +- `NeuronGemma4DecoderLayer` — **parallel-MoE layout**: + - dense MLP and MoE branch run on the **post-norm residual** in + parallel (HF source lines 1429–1441). + - `mlp_branch + moe_branch` ⇒ `layer_scalar`-multiplied final residual. + - **Dual-input MoE forward**: the router sees the *raw* residual while + the experts see `post_feedforward_layernorm_2(residual)`. Necessary + to match the HF reference; the two pre-norm streams differ. +- `convert_hf_to_neuron_state_dict` — extended for MoE: + - Stacks per-expert `gate_up_proj.weight` and `down_proj.weight` to + shape `[num_experts, ...]` for `moe_v2`. + - Renames the gemma4 router weight (`gating.weight` ⇒ + `router.weight`). + - Wires the shared-expert weights through the dense MLP path + (`shared_experts.{gate,up,down}_proj` ⇒ `mlp.{gate,up,down}_proj`). + - Pre-scales `q_layernorm.weight` by `sqrt(head_dim)` (PR #106's + trick, kept for parity). + +**Config knobs that differ from a stock NxDI MoE:** + +- `disable_normalize_top_k_affinities=True` — gemma4 already renormalises + + applies `per_expert_scale` inside the custom router; we want NxDI to + consume our affinities verbatim. +- `router_dtype="float32"`, `router_act_fn="softmax"` — match HF + reference; underlying NxDI `RouterConfig` reads these for typing. +- `glu_mlp=True`, `glu_type="glu"` — gemma4 expert MLP is gated. + +### `src/configuration_gemma4_neuron.py` + +Lightweight HF-style config dataclass split out for static parsing. +PR #106 keeps its config inline. Splitting it lets external tools read +`hidden_size` / `num_experts` / `top_k` without importing NxDI. + +### Test layout + +`test/integration/test_model.py` mirrors PR #106's layout but is reduced +to a Stage 1 / Stage 2 / Stage 3 smoke runner (compile dense, compile +MoE, generate ≤ 8 tokens). Token-match accuracy is a follow-up. 
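The routing steps listed for `NeuronGemma4Router` above (FP32 softmax, top-k, renormalise, per-expert scale) can be illustrated with a minimal pure-Python sketch. This is not the port's implementation: `route_token` is a hypothetical name, the inputs are plain lists rather than tensors, and the learned `scale` tensor is left out because this document does not say where it is applied.

```python
import math


def route_token(router_logits, per_expert_scale, top_k):
    """Illustrative routing for ONE token (hypothetical helper, not the port's code).

    Returns (top_k_index, top_k_weights).
    """
    # Numerically stable softmax over all experts.
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Indices of the top_k most probable experts, most probable first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top_k_index = order[:top_k]

    # Renormalise so the selected probabilities sum to 1,
    # then apply the per-expert learned scale.
    selected_sum = sum(probs[i] for i in top_k_index)
    top_k_weights = [
        probs[i] / selected_sum * per_expert_scale[i] for i in top_k_index
    ]
    return top_k_index, top_k_weights


if __name__ == "__main__":
    index, weights = route_token([0.0, 2.0, 1.0, -1.0], [1.0] * 4, top_k=2)
    print(index, weights)
```

Because the weights leave the router already renormalised and scaled, any further normalisation downstream would distort them, which is why the port sets `disable_normalize_top_k_affinities=True`.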
+ +## What is **not** in this PR (deferred) + +- Vision / audio towers — text-only port. Use PR #106 / #109 for VLM. +- Token-match accuracy validation vs HF reference (sampling, chat + template, longer prompts). +- `seq_len > 256` — round 4 only validated 256. Longer sequence compile + is a follow-up. +- vLLM serving notebook (PR #106 has one). diff --git a/contrib/models/gemma-4-26b-a4b-it/README.md b/contrib/models/gemma-4-26b-a4b-it/README.md new file mode 100644 index 00000000..5218c871 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/README.md @@ -0,0 +1,176 @@ +# Contrib Model: Gemma 4 26B-A4B-it + +NeuronX Distributed Inference port of `google/gemma-4-26B-A4B-it`, an MoE +text-only sibling of Gemma 4 31B-IT (PR #106). + +## Model Information + +- **HuggingFace ID:** [`google/gemma-4-26B-A4B-it`](https://huggingface.co/google/gemma-4-26B-A4B-it) +- **Model Type:** Text-only Mixture-of-Experts decoder +- **Parameters:** ~25.2B total / ~3.8B active +- **License:** Check HuggingFace model card + +## Architecture Details + +Gemma 4 26B-A4B-it shares the Gemma 4 attention + LM-head stack with PR #106 +(31B-IT) but replaces the dense FFN with a parallel-MoE block. 
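The parallel-MoE block can be pictured with a toy data-flow sketch. Everything in it is an assumption made for illustration: `parallel_ffn` and its callable arguments are hypothetical, values are plain floats rather than tensors, and the real layer's exact norm placement and residual add are not reproduced. It only shows the three facts stated in this PR: the dense MLP and the MoE branch run side by side, the router and the experts receive different inputs, and the branch outputs are summed before `layer_scalar` is applied.

```python
def parallel_ffn(residual, dense_mlp, router, experts, mlp_norm, expert_norm, layer_scalar):
    """Toy data flow for the parallel dense-MLP + MoE block (hypothetical helper)."""
    # Dense branch: runs on a normalised copy of the residual.
    mlp_branch = dense_mlp(mlp_norm(residual))

    # MoE branch, dual input: the router sees the raw residual,
    # while the experts see a separately normalised copy.
    top_k_index, top_k_weights = router(residual)
    moe_branch = experts(expert_norm(residual), top_k_index, top_k_weights)

    # Branch outputs are summed, then scaled by layer_scalar.
    return (mlp_branch + moe_branch) * layer_scalar


if __name__ == "__main__":
    out = parallel_ffn(
        residual=8.0,
        dense_mlp=lambda x: 2.0 * x,
        router=lambda x: ([0], [1.0]),
        experts=lambda x, index, weights: sum(w * x for w in weights),
        mlp_norm=lambda x: x / 2.0,
        expert_norm=lambda x: x / 4.0,
        layer_scalar=0.5,
    )
    print(out)
```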
+ +| Feature | Description | +|---------|-------------| +| **Layers** | 30 decoder layers | +| **Hidden / Intermediate (dense)** | 2816 / 2112 | +| **Attention heads** | 16 attention, 8 KV (GQA 2:1) | +| **Heterogeneous attention** | SWA layers (head_dim=256) and Global layers (head_dim=512) — same as 31B-IT | +| **`attention_k_eq_v`** | Global layers share K/V projections | +| **QK / V normalisation** | RMSNorm on Q and K post-projection; V uses RMSNorm without learnable scale | +| **Partial RoPE on global** | `partial_rotary_factor=0.25` (128 of 512 dims rotated) | +| **Final logit softcap** | `30 * tanh(logits / 30)` | +| **Scaled embeddings** | `embed * sqrt(hidden_size)` | +| **MoE block** | 128 routed experts, `top_k=8`, plus 1 shared (parallel) dense MLP | +| **Router** | FP32 softmax + top-k + renormalise + per-expert learned scale | +| **Decoder layout** | dense MLP and MoE branch run in **parallel** on the post-norm residual; outputs summed before `layer_scalar` | +| **Per-layer-input embed** | `hidden_size_per_layer_input=0` (PLE disabled — differs from earlier Gemma) | + +The NKI flash-attention kernels (`nki_flash_attn_d256_swa.py` for SWA layers, +`nki_flash_attn_large_d.py` for global head_dim>128 layers) are imported +verbatim from PR #106. Head dimensions are unchanged so the kernels apply +without modification. + +## Validation Results + +**Validated:** 2026-06-03 (round 4). +**Configuration:** TP=8, batch_size=1, bfloat16, seq_len=256, LNC=2. +**Instance:** trn2.48xlarge. +**SDK:** 2.29 (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`, torch 2.9.1, NxDI 0.10.0). + +### Stage 1 — DISABLE_MOE compile + load + +Validates the attention / scaled-embed / softcap / dense-MLP path without +the MoE branch. 
+ +| Metric | Value | +|--------|-------| +| Compile | 2.2 min (priority HLO 81 s, all HLOs 17 s) | +| Weight load | 20.85 s | +| Warmup | 0.49 s | +| NEFF artifact dir | 17 MB | +| Status | **PASS** | + +### Stage 2 — MoE-on compile + load + +Adds the 128-expert routed MoE branch (parallel to the dense MLP). + +| Metric | Value | +|--------|-------| +| Compile | 19.7 min (priority HLO 106 s, all HLOs 925 s, build 1183 s) | +| Weight load | 29.1 s | +| Warmup | 0.66 s | +| NEFF artifact dir | 297 MB | +| Status | **PASS** | + +MoE compile requires `--internal-hlo2tensorizer-options='--verify-hlo=false'` +(genericmoe v16 KB) — set in `Gemma4NeuronConfig`. + +### Stage 3 — Inference smoke + +| Metric | Value | +|--------|-------| +| Prompt | `"Hello, my name is"` (5 tokens) | +| Generated | 8 tokens, decoded `", my name is, my name is"` | +| TTFT (prefill seq_len=256) | 309.5 ms | +| TPOT | 8.79 ms | +| Throughput | 114 tok/s | +| Status | **PASS** (coherence smoke; greedy + base-style continuation, no chat template) | + +## What was reused from existing NxDI + +- `NeuronAttentionBase` — Q/K/V/o projections, KV cache, GQA sharding, mask + builders. Overrides: `apply_rotary_embedding` (partial RoPE), + `prep_qkv_tensors` (post-projection v_norm), `perform_prefill` (NKI + d=256 SWA kernel). +- `RotaryEmbedding` — instantiated per-layer with the right `dim` for + partial RoPE on global layers. +- `ColumnParallelLinear` / `RowParallelLinear` / `ParallelEmbedding` — for + dense MLP, lm_head, token embedding. +- `initialize_moe_module` (NxDI `moe_v2`) — handles expert dispatch and + sharded `gate_up_proj` / `down_proj`. We feed it our own `top_k_index` / + `top_k_weights` from the gemma4 router. +- `KVCacheManager` — subclassed to support per-layer heterogeneous shapes + (8×256 SWA vs 2×512 global, after TP sharding). +- `NeuronBaseForCausalLM` / `NeuronBaseModel` — generation loop, sampling, + weight loading. 
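To make the "per-layer heterogeneous shapes" point concrete, the sketch below derives one `(kv_heads, head_dim)` pair per layer from the default layer pattern in this patch's config shim (five sliding layers to one full-attention layer, last layer full). It is a hedged illustration only: `kv_shapes_per_layer` is a hypothetical helper, the numbers are config-level values, and how NxDI splits them across TP ranks is not modelled. A checkpoint whose `config.json` lists `layer_types` explicitly would take precedence over this derived pattern.

```python
def kv_shapes_per_layer(
    num_hidden_layers=30,
    sliding_window_pattern=6,
    sliding_kv=(8, 256),  # (num_key_value_heads, head_dim) for SWA layers
    global_kv=(2, 512),   # (num_global_key_value_heads, global_head_dim)
):
    """Return one (kv_heads, head_dim) tuple per decoder layer (hypothetical helper)."""
    shapes = []
    for i in range(num_hidden_layers):
        # Every sliding_window_pattern-th layer is full attention,
        # and the last layer is always full attention.
        is_full = (i + 1) % sliding_window_pattern == 0 or i == num_hidden_layers - 1
        shapes.append(global_kv if is_full else sliding_kv)
    return shapes


if __name__ == "__main__":
    for layer, shape in enumerate(kv_shapes_per_layer()):
        print(layer, shape)
```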
+ +## What was borrowed from PR #106 (31B-IT) + +- `nki_flash_attn_d256_swa.py` — verbatim. +- `nki_flash_attn_large_d.py` — verbatim. +- `ndxi_patch.py` — verbatim, with one-line tweak so the relative import + works inside this `src/` package. +- KV cache manager with per-layer cache size mapping. +- `SoftcappedLMHead` (cap = 30.0). +- `Gemma4ScaledEmbedding` (multiplies by `sqrt(hidden_size)`). +- `Gemma4RMSNorm` / `Gemma4VNorm`. +- Q-norm pre-scaling trick (cancels NxDI's automatic `1/sqrt(head_dim)` so + that gemma4's QK-norm + scale match the HF reference). + +## What is 26B-A4B-specific + +| Class | Reason | +|---|---| +| `NeuronGemma4Router` | gemma4 router with `scale` + `per_expert_scale` learned tensors, FP32 softmax + top-k + renormalise + per-expert scale. PR #106 has no router (dense). | +| `NeuronGemma4MoEBlock` | Wraps NxDI `initialize_moe_module` and feeds it gemma4-flavoured `top_k_index` / `top_k_weights`. | +| `NeuronGemma4DecoderLayer` | Parallel-MoE layout: dense MLP and MoE branch operate on the post-norm residual concurrently; combined output = `mlp_branch + moe_branch`, then `layer_scalar`-multiplied. **Dual-input MoE forward**: the router sees the raw residual, while experts see `post_feedforward_layernorm_2(residual)` — matches HF source lines 1429–1441. | +| `convert_hf_to_neuron_state_dict` | Extends PR #106's converter with MoE weight stacking (`gate_up_proj.weight` / `down_proj.weight` shaped `[num_experts, ...]`), router weight rename, shared-expert weight wiring, and `disable_normalize_top_k_affinities` so NxDI uses our pre-computed expert affinities verbatim. | + +## Open issues / known limitations + +- **Smoke-test only**: greedy + no chat template ⇒ Stage 3 output repeats. + Token-match accuracy vs HF reference still pending (sampling + + chat-formatted prompts). 
+- **AutoTokenizer fix**: HF transformers ≤ 4.45 trips on gemma-4's + special-tokens list-vs-dict shape; `scripts/smoke_inference.py` falls + back to the raw `tokenizers` Rust backend. +- **Multimodal towers** (vision, audio) are **not ported** — text-only. + Use PR #106 / PR #109 for VLM. +- **Sequence length tested:** 256. Longer `seq_len` (1024 / 2048) compile + is a follow-up. +- **NxDI ≥ 0.10** required (per-layer `layer_to_cache_size_mapping` and + `get_last_kv_window` patch). +- Apply `ndxi_patch.apply_patch()` once at process start before + constructing the model class — see `scripts/smoke_compile.py`. + +## How to compile and run (on trn2.48xlarge) + +```bash +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# 1. Stage 1 — dense path only (fast smoke) +GEMMA4_DISABLE_MOE=1 PYTHONPATH=src \ + python scripts/smoke_compile.py 2>&1 | tee compile_disable_moe.log + +# 2. Stage 2 — MoE on +PYTHONPATH=src \ + python scripts/smoke_compile.py 2>&1 | tee compile_moe_on.log + +# 3. Stage 3 — generate +PYTHONPATH=src \ + python scripts/smoke_inference.py 2>&1 | tee inference.log +``` + +Environment overrides (compile and inference must agree): + +| Var | Default | Notes | +|---|---|---| +| `GEMMA4_MODEL_PATH` | `/home/ubuntu/gemma4-26b-a4b` | HF checkpoint dir | +| `GEMMA4_COMPILED_PATH` | `/home/ubuntu/gemma4-compiled` | NEFF output dir | +| `GEMMA4_TP_DEGREE` | `8` | Tensor-parallel degree | +| `GEMMA4_BATCH_SIZE` | `1` | – | +| `GEMMA4_SEQ_LEN` | `256` | Compile-time max seq | +| `GEMMA4_DISABLE_MOE` | `0` | `1` ⇒ dense smoke | +| `GEMMA4_MOE_EP_DEGREE` | `1` | Keep at 1 unless `BS ≥ 32` | +| `GEMMA4_MOE_TP_DEGREE` | `` | Match `tp_degree` | + +## Diff vs PR #106 + +See [`DIFF_FROM_PR106.md`](DIFF_FROM_PR106.md) for a structural diff against +Jim Burtoft's 31B-IT port — the canonical review companion. 
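The environment-override table above notes that compile and inference must agree. One way to guard against a silent mismatch is to store the effective settings next to the compiled artifacts and compare them before loading. The sketch below is an assumption, not part of the shipped scripts: `snapshot`, `save_snapshot`, `check_snapshot`, and the `smoke_env.json` file name are all hypothetical.

```python
import json
import os
from pathlib import Path

# Defaults copied from the environment-override table.
_DEFAULTS = {
    "GEMMA4_TP_DEGREE": "8",
    "GEMMA4_BATCH_SIZE": "1",
    "GEMMA4_SEQ_LEN": "256",
    "GEMMA4_DISABLE_MOE": "0",
    "GEMMA4_MOE_EP_DEGREE": "1",
}
_SNAPSHOT_NAME = "smoke_env.json"  # hypothetical file name


def snapshot(env=None):
    """Collect the effective GEMMA4_* settings as strings."""
    env = os.environ if env is None else env
    snap = {key: env.get(key, default) for key, default in _DEFAULTS.items()}
    # The smoke scripts default the MoE TP degree to the TP degree.
    snap["GEMMA4_MOE_TP_DEGREE"] = env.get(
        "GEMMA4_MOE_TP_DEGREE", snap["GEMMA4_TP_DEGREE"]
    )
    return snap


def save_snapshot(compiled_dir, env=None):
    """Write the settings used at compile time into the compiled directory."""
    path = Path(compiled_dir) / _SNAPSHOT_NAME
    path.write_text(json.dumps(snapshot(env), indent=2, sort_keys=True))
    return path


def check_snapshot(compiled_dir, env=None):
    """Raise if the current settings differ from the ones saved at compile time."""
    saved = json.loads((Path(compiled_dir) / _SNAPSHOT_NAME).read_text())
    current = snapshot(env)
    mismatched = {
        key: (saved.get(key), value)
        for key, value in current.items()
        if saved.get(key) != value
    }
    if mismatched:
        raise RuntimeError(f"compile/inference settings differ: {mismatched}")
```

In this sketch, `save_snapshot(COMPILED_PATH)` would run after a successful compile and `check_snapshot(COMPILED_PATH)` before `model.load(...)` in the inference script.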
diff --git a/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_compile.py b/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_compile.py new file mode 100644 index 00000000..f8b3ff0b --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_compile.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +"""Smoke compile script for Gemma-4-26B-A4B-it on Trainium 2. + +Following PR #106's `test/integration/test_model.py` shape. Run on a trn2 +host with the NxDI venv activated. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + cd contrib/models/gemma-4-26b-a4b-it + # IMPORTANT: trn2.3xlarge defaults to LNC=2 (4 logical cores). If you + # want TP=8 you MUST set NEURON_LOGICAL_NC_CONFIG=1 first; otherwise + # the runtime fails with c10::Error inside NeuronAllocators::Get when + # the spmd state initializer tries to copy zeros to ranks 4..7. + NEURON_LOGICAL_NC_CONFIG=1 PYTHONPATH=src python scripts/smoke_compile.py 2>&1 | tee ~/compile.log + +Environment overrides: + GEMMA4_MODEL_PATH (default: /home/ubuntu/gemma4-26b-a4b) + GEMMA4_COMPILED_PATH (default: /home/ubuntu/gemma4-compiled) + GEMMA4_TP_DEGREE (default: 8; set to 4 if you do not want to set LNC=1) + GEMMA4_BATCH_SIZE (default: 1) + GEMMA4_SEQ_LEN (default: 256, kept short for first compile) + GEMMA4_DISABLE_MOE (default: 0; set to 1 for dense-only smoke) + GEMMA4_MOE_EP_DEGREE (default: 1) + GEMMA4_MOE_TP_DEGREE (default: ) +""" + +import json +import os +import sys +import time +from pathlib import Path + +import torch + +# Apply NxDI runtime patches (NKI kernel for d>128, get_last_kv_window fix). 
+import ndxi_patch # noqa: E402 + +ndxi_patch.apply_patch() + +from modeling_gemma4_neuron import ( # noqa: E402 + Gemma4InferenceConfig, + Gemma4NeuronConfig, + NeuronGemma4ForCausalLM, +) + + +MODEL_PATH = os.environ.get("GEMMA4_MODEL_PATH", "/home/ubuntu/gemma4-26b-a4b") +COMPILED_PATH = os.environ.get("GEMMA4_COMPILED_PATH", "/home/ubuntu/gemma4-compiled") +TP_DEGREE = int(os.environ.get("GEMMA4_TP_DEGREE", "8")) +BATCH_SIZE = int(os.environ.get("GEMMA4_BATCH_SIZE", "1")) +SEQ_LEN = int(os.environ.get("GEMMA4_SEQ_LEN", "256")) +MOE_EP_DEGREE = int(os.environ.get("GEMMA4_MOE_EP_DEGREE", "1")) +MOE_TP_DEGREE = int(os.environ.get("GEMMA4_MOE_TP_DEGREE", str(TP_DEGREE))) + + +def create_config(model_path: str) -> Gemma4InferenceConfig: + neuron_config = Gemma4NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + on_device_sampling_config=None, + torch_dtype=torch.bfloat16, + fused_qkv=False, + attn_kernel_enabled=False, + # MoE knobs (consumed by MoENeuronConfig __init__): + moe_ep_degree=MOE_EP_DEGREE, + moe_tp_degree=MOE_TP_DEGREE, + glu_mlp=True, + glu_type="glu", + # Gemma's router runs softmax in FP32. We fold this into the + # custom NeuronGemma4Router, but the underlying NxDI RouterConfig + # is still consulted by `initialize_moe_module` for typing of any + # internal router-state buffers, so set sensible values here. + router_act_fn="softmax", + router_dtype="float32", + # Gemma renormalizes top-k weights INSIDE the custom router and + # bakes per_expert_scale in there. Disable NxDI's renorm so it + # uses our pre-computed expert_affinities verbatim. 
+ disable_normalize_top_k_affinities=True, + ) + + def load_config_fn(config_obj): + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: + config_dict = json.load(f) + for k, v in config_dict.items(): + setattr(config_obj, k, v) + + cfg = Gemma4InferenceConfig( + neuron_config=neuron_config, load_config=load_config_fn + ) + # Smoke-compile flag: set GEMMA4_DISABLE_MOE=1 to validate the rest of + # the architecture without the MoE branch (dense MLP only). MoE + # integration with NxDI moe_v2 has separate process-group bring-up + # work; see README "Open issues / known limitations". + if os.environ.get("GEMMA4_DISABLE_MOE", "0") == "1": + cfg.disable_moe_for_smoke_compile = True + return cfg + + +def main() -> int: + print("=" * 80) + print("Gemma-4-26B-A4B-it smoke compile") + print(f" model_path: {MODEL_PATH}") + print(f" compiled_path: {COMPILED_PATH}") + print(f" tp_degree: {TP_DEGREE}") + print(f" batch_size: {BATCH_SIZE}") + print(f" seq_len: {SEQ_LEN}") + print("=" * 80) + + if not Path(MODEL_PATH).exists(): + print(f"ERROR: model path {MODEL_PATH} does not exist", file=sys.stderr) + return 1 + + config = create_config(MODEL_PATH) + print( + f"Config loaded: hidden_size={config.hidden_size}, " + f"num_layers={config.num_hidden_layers}, " + f"num_experts={getattr(config, 'num_experts', None)}, " + f"top_k={getattr(config, 'top_k_experts', None)}" + ) + + print(f"\nCompiling to {COMPILED_PATH} ...") + t0 = time.perf_counter() + model = NeuronGemma4ForCausalLM(MODEL_PATH, config) + model.compile(COMPILED_PATH) + elapsed = time.perf_counter() - t0 + print(f"\nCompile finished in {elapsed/60:.1f} min") + + print("\nLoading compiled model ...") + model = NeuronGemma4ForCausalLM(MODEL_PATH, config) + model.load(COMPILED_PATH) + print("Smoke compile + load OK") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_inference.py 
b/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_inference.py new file mode 100644 index 00000000..cc3b3746 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/scripts/smoke_inference.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +"""Smoke inference for Gemma-4-26B-A4B-it on Trainium 2. + +Goal: confirm the compiled+loaded model produces non-trivial output for a +short prompt. We do NOT validate accuracy here — that's downstream. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + cd contrib/models/gemma-4-26b-a4b-it + PYTHONPATH=src python scripts/smoke_inference.py 2>&1 | tee ~/inference.log + +Environment overrides (must match the values used by smoke_compile.py): + GEMMA4_MODEL_PATH (default: /home/ubuntu/gemma4-26b-a4b) + GEMMA4_COMPILED_PATH (default: /home/ubuntu/gemma4-compiled) + GEMMA4_TP_DEGREE (default: 8) + GEMMA4_BATCH_SIZE (default: 1) + GEMMA4_SEQ_LEN (default: 256) + GEMMA4_MAX_NEW_TOKENS (default: 8) + GEMMA4_PROMPT (default: "Hello, my name is") +""" + +import json +import os +import sys +import time +from pathlib import Path + +import torch + +import ndxi_patch # noqa: E402 + +ndxi_patch.apply_patch() + +from modeling_gemma4_neuron import ( # noqa: E402 + Gemma4InferenceConfig, + Gemma4NeuronConfig, + NeuronGemma4ForCausalLM, +) + + +MODEL_PATH = os.environ.get("GEMMA4_MODEL_PATH", "/home/ubuntu/gemma4-26b-a4b") +COMPILED_PATH = os.environ.get("GEMMA4_COMPILED_PATH", "/home/ubuntu/gemma4-compiled") +TP_DEGREE = int(os.environ.get("GEMMA4_TP_DEGREE", "8")) +BATCH_SIZE = int(os.environ.get("GEMMA4_BATCH_SIZE", "1")) +SEQ_LEN = int(os.environ.get("GEMMA4_SEQ_LEN", "256")) +MAX_NEW_TOKENS = int(os.environ.get("GEMMA4_MAX_NEW_TOKENS", "8")) +PROMPT = os.environ.get("GEMMA4_PROMPT", "Hello, my name is") +MOE_EP_DEGREE = int(os.environ.get("GEMMA4_MOE_EP_DEGREE", "1")) +MOE_TP_DEGREE = int(os.environ.get("GEMMA4_MOE_TP_DEGREE", str(TP_DEGREE))) + + +def create_config(model_path: str) -> Gemma4InferenceConfig: + 
neuron_config = Gemma4NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + on_device_sampling_config=None, + torch_dtype=torch.bfloat16, + fused_qkv=False, + attn_kernel_enabled=False, + moe_ep_degree=MOE_EP_DEGREE, + moe_tp_degree=MOE_TP_DEGREE, + glu_mlp=True, + glu_type="glu", + router_act_fn="softmax", + router_dtype="float32", + disable_normalize_top_k_affinities=True, + ) + + def load_config_fn(config_obj): + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: + config_dict = json.load(f) + for k, v in config_dict.items(): + setattr(config_obj, k, v) + + cfg = Gemma4InferenceConfig( + neuron_config=neuron_config, load_config=load_config_fn + ) + if os.environ.get("GEMMA4_DISABLE_MOE", "0") == "1": + cfg.disable_moe_for_smoke_compile = True + return cfg + + +def _load_tokenizer(model_path): + """Try several backends — installed transformers may break on gemma-4.""" + # 1. AutoTokenizer (preferred but may fail on gemma-4 special tokens). + try: + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained(model_path) + except Exception as e: + print(f"AutoTokenizer failed: {e}; falling back to tokenizers backend") + # 2. Raw tokenizers (HF Rust). Reads tokenizer.json directly. + from tokenizers import Tokenizer + + class _Wrapped: + def __init__(self, t, eos_ids): + self._t = t + self.eos_token_id = eos_ids + + def __call__(self, text, return_tensors=None): + enc = self._t.encode(text) + ids = torch.tensor([enc.ids], dtype=torch.long) + return {"input_ids": ids} + + def decode(self, ids, skip_special_tokens=True): + return self._t.decode(ids, skip_special_tokens=skip_special_tokens) + + t = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json")) + # gemma-4 generation_config eos: 106 + 1. 
+ return _Wrapped(t, [1, 106]) + + +def generate(model, tokenizer, prompt, max_new_tokens): + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"] + seq_len = input_ids.shape[1] + n_positions = SEQ_LEN + + if seq_len > n_positions: + raise RuntimeError(f"prompt length {seq_len} exceeds compiled seq_len {n_positions}") + + pad_len = n_positions - seq_len + input_ids_padded = torch.cat( + [input_ids, torch.zeros(1, pad_len, dtype=torch.long)], dim=1 + ) + attention_mask = torch.cat( + [ + torch.ones(1, seq_len, dtype=torch.long), + torch.zeros(1, pad_len, dtype=torch.long), + ], + dim=1, + ) + position_ids = torch.zeros(1, n_positions, dtype=torch.long) + position_ids[0, :seq_len] = torch.arange(seq_len) + + timing = {} + t0 = time.perf_counter() + with torch.no_grad(): + outputs = model( + input_ids=input_ids_padded, + attention_mask=attention_mask, + position_ids=position_ids, + ) + timing["ttft_ms"] = (time.perf_counter() - t0) * 1000 + + if hasattr(outputs, "logits") and outputs.logits is not None: + logits = outputs.logits + next_token_logits = logits[:, -1, :] if logits.dim() == 3 else logits + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + elif hasattr(outputs, "tokens") and outputs.tokens is not None: + next_token = outputs.tokens[:, -1:] + else: + raise RuntimeError(f"Unexpected output type: {type(outputs)}") + + generated = [int(next_token.item())] + cur_pos = seq_len + + t_gen = time.perf_counter() + for _ in range(max_new_tokens - 1): + attention_mask[0, cur_pos] = 1 + with torch.no_grad(): + outputs = model( + input_ids=next_token, + attention_mask=attention_mask, + position_ids=torch.tensor([[cur_pos]]), + ) + cur_pos += 1 + if hasattr(outputs, "logits") and outputs.logits is not None: + logits = outputs.logits + next_token_logits = logits[:, -1, :] if logits.dim() == 3 else logits + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + elif hasattr(outputs, "tokens") and outputs.tokens 
is not None: + next_token = outputs.tokens[:, -1:] + else: + break + generated.append(int(next_token.item())) + eos = tokenizer.eos_token_id + if isinstance(eos, list): + if next_token.item() in eos: + break + elif next_token.item() == eos: + break + + t_end = time.perf_counter() + n_decode = len(generated) - 1 + if n_decode > 0: + timing["tpot_ms"] = (t_end - t_gen) / n_decode * 1000 + timing["throughput_tps"] = n_decode / (t_end - t_gen) + timing["total_tokens"] = len(generated) + return generated, tokenizer.decode(generated, skip_special_tokens=True), timing + + +def main() -> int: + print("=" * 80) + print("Gemma-4-26B-A4B-it smoke inference") + print(f" model_path: {MODEL_PATH}") + print(f" compiled_path: {COMPILED_PATH}") + print(f" tp_degree: {TP_DEGREE}") + print(f" seq_len: {SEQ_LEN}") + print(f" prompt: {PROMPT!r}") + print("=" * 80) + + if not Path(COMPILED_PATH).exists(): + print(f"ERROR: compiled path {COMPILED_PATH} does not exist", file=sys.stderr) + return 1 + + config = create_config(MODEL_PATH) + print("Loading compiled model ...") + t0 = time.perf_counter() + model = NeuronGemma4ForCausalLM(MODEL_PATH, config) + model.load(COMPILED_PATH) + print(f"Load took {time.perf_counter() - t0:.1f}s") + + print("Loading tokenizer ...") + tokenizer = _load_tokenizer(MODEL_PATH) + + print(f"\nGenerating {MAX_NEW_TOKENS} tokens ...") + tokens, text, timing = generate(model, tokenizer, PROMPT, MAX_NEW_TOKENS) + print(f"\nGenerated tokens: {tokens}") + print(f"Decoded: {text!r}") + print(f"Timing: {timing}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/contrib/models/gemma-4-26b-a4b-it/src/__init__.py b/contrib/models/gemma-4-26b-a4b-it/src/__init__.py new file mode 100644 index 00000000..5f051549 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/__init__.py @@ -0,0 +1,18 @@ +# NeuronX Distributed Inference port of google/gemma-4-26B-A4B-it. 
+# +# Public surface mirrors PR #106 (gemma-4-31b-it) but text-only and +# MoE-aware. See README.md for status and usage. + +from .configuration_gemma4_neuron import Gemma4TextConfig # noqa: F401 +from .modeling_gemma4_neuron import ( # noqa: F401 + Gemma4InferenceConfig, + Gemma4NeuronConfig, + NeuronGemma4ForCausalLM, +) + +__all__ = [ + "Gemma4TextConfig", + "Gemma4InferenceConfig", + "Gemma4NeuronConfig", + "NeuronGemma4ForCausalLM", +] diff --git a/contrib/models/gemma-4-26b-a4b-it/src/configuration_gemma4_neuron.py b/contrib/models/gemma-4-26b-a4b-it/src/configuration_gemma4_neuron.py new file mode 100644 index 00000000..709240be --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/configuration_gemma4_neuron.py @@ -0,0 +1,111 @@ +# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# ========================================================================== +# +# Gemma-4-26B-A4B-it configuration shim. +# +# Round 2 simplification: the actual config classes +# (`Gemma4InferenceConfig`, `Gemma4NeuronConfig`) live in +# `modeling_gemma4_neuron.py` (PR #106 pattern). This file used to host +# factory functions that delayed NxDI imports for laptop-side parsing; that +# was removed because every consumer of these classes already imports from +# `modeling_gemma4_neuron` (which itself requires NxDI). Keeping the +# factories was dead code that just hid bugs. +# +# What stays here: a small HF-side compatibility shim +# (`Gemma4TextConfig`) that lets static tools parse this package without the +# upstream gemma4 transformers branch installed. It is unused at runtime +# when the real `transformers.models.gemma4` is importable. 
+ +from __future__ import annotations + +from typing import Any, Optional + + +try: + from transformers.models.gemma4.configuration_gemma4 import ( # type: ignore[import-not-found] + Gemma4TextConfig as _HFGemma4TextConfig, + ) +except Exception: # pragma: no cover - shim path + _HFGemma4TextConfig = None # noqa: N816 + + +class _Gemma4TextConfigShim: + """Mirror the fields of HF `Gemma4TextConfig` we need for static parsing.""" + + model_type = "gemma4_text" + + # 26B-A4B values from the actual HF config.json (see traces/round2_diff.md + # for the corrections vs round-1 guesses). + vocab_size: int = 262_144 + hidden_size: int = 2816 + intermediate_size: int = 2112 + num_hidden_layers: int = 30 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + head_dim: int = 256 + hidden_activation: str = "gelu_pytorch_tanh" + max_position_embeddings: int = 262_144 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: Optional[int] = 0 + eos_token_id: Any = 1 + bos_token_id: Optional[int] = 2 + tie_word_embeddings: bool = True + rope_parameters: Optional[dict] = None + attention_bias: bool = False + attention_dropout: float = 0.0 + sliding_window: int = 1024 + layer_types: Optional[list] = None + final_logit_softcapping: Optional[float] = 30.0 + use_bidirectional_attention: Optional[str] = None + # 26B-A4B: PLE disabled (hidden_size_per_layer_input == 0). + vocab_size_per_layer_input: int = 262_144 + hidden_size_per_layer_input: int = 0 + num_global_key_value_heads: Optional[int] = 2 + global_head_dim: int = 512 + attention_k_eq_v: bool = True + num_kv_shared_layers: int = 0 + # 26B-A4B has MoE enabled. 
+ enable_moe_block: bool = True + use_double_wide_mlp: bool = False + num_experts: Optional[int] = 128 + top_k_experts: Optional[int] = 8 + moe_intermediate_size: Optional[int] = 704 + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + setattr(self, k, v) + self._post_init() + + def _post_init(self) -> None: + if self.layer_types is None: + # 5 sliding : 1 full pattern, last layer == full (per HF post-init). + sliding_window_pattern = 6 + self.layer_types = [ + "sliding_attention" if (i + 1) % sliding_window_pattern else "full_attention" + for i in range(self.num_hidden_layers) + ] + self.layer_types[-1] = "full_attention" + + if self.rope_parameters is None: + self.rope_parameters = { + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10_000.0, + }, + "full_attention": { + "rope_type": "proportional", + "partial_rotary_factor": 0.25, + "rope_theta": 1_000_000.0, + }, + } + + +Gemma4TextConfig = _HFGemma4TextConfig if _HFGemma4TextConfig is not None else _Gemma4TextConfigShim + + +__all__ = ["Gemma4TextConfig"] diff --git a/contrib/models/gemma-4-26b-a4b-it/src/modeling_gemma4_neuron.py b/contrib/models/gemma-4-26b-a4b-it/src/modeling_gemma4_neuron.py new file mode 100644 index 00000000..c90d8922 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/modeling_gemma4_neuron.py @@ -0,0 +1,1356 @@ +# coding=utf-8 +# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# ========================================================================== +# NeuronX Distributed Inference port of google/gemma-4-26B-A4B-it. 
+# ========================================================================== +# +# Round 2 — heavily borrowed from Jim Burtoft's PR #106 (gemma-4-31B-IT) for +# attention, KV cache, softcapping, and weight conversion. The MoE block, +# router, and decoder-layer MoE branch are 26B-A4B-specific (the 31B model +# is dense). NKI flash attention kernels (`nki_flash_attn_d256_swa.py` and +# `nki_flash_attn_large_d.py`) are taken verbatim from PR #106; head_dim +# values match (256 sliding / 512 full) so they work as-is. +# +# Differences from PR #106 worth knowing about: +# * 26B-A4B has `enable_moe_block=True`, `num_experts=128`, `top_k=8`. +# Each decoder layer runs the dense MLP and the MoE block in parallel +# (HF source lines 1429-1441), then sums their outputs. +# * `hidden_size=2816` (vs 31B's 5376), `num_attention_heads=16`, +# `num_key_value_heads=8` for sliding, `num_global_key_value_heads=2`. +# * `final_logit_softcapping=30.0` (same as 31B-IT). +# * `hidden_size_per_layer_input=0` — no per-layer-embedding (PLE) on +# 26B-A4B, so the round-1 PLE code is dropped. +# +# To use the NKI kernels at runtime, callers must invoke +# `import ndxi_patch; ndxi_patch.apply_patch()` before +# constructing `NeuronGemma4ForCausalLM`. 
+ +from __future__ import annotations + +import copy +import math +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import mappings as _nxd_parallel_mappings + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + FlashAttentionStrategy, + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import ( + RotaryEmbedding, + apply_rotary_pos_emb, +) +from neuronx_distributed_inference.modules.attention.gqa import ( + determine_sharding_strategy, + get_shardable_head_counts, +) +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import ( + KVCacheManager, +) +from neuronx_distributed_inference.modules.kvcache.utils import get_kv_shapes + +# MoE module from NxDI v2. +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +# NKI flash attention kernel for head_dim=256 SWA layers (sliding window). +try: + try: + from .nki_flash_attn_d256_swa import flash_attn_d256_swa as _nki_flash_attn_d256_swa # type: ignore[import-not-found] + except ImportError: + # Fallback for top-level (PYTHONPATH=src) layout. 
+ from nki_flash_attn_d256_swa import flash_attn_d256_swa as _nki_flash_attn_d256_swa # type: ignore[import-not-found] + + _HAS_NKI_SWA_KERNEL = True +except Exception: # pragma: no cover - kernel optional at import time + _HAS_NKI_SWA_KERNEL = False + + +# ==================================================================================== +# Normalization (PR #106 pattern: Gemma4RMSNorm with weight, Gemma4VNorm without) +# ==================================================================================== + + +class Gemma4RMSNorm(nn.Module): + """Standard Gemma4 RMSNorm: normed * weight (weight init to ones).""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()) + output = output * self.weight.float() + return output.type_as(x) + + +class Gemma4VNorm(nn.Module): + """Gemma4 v_norm: RMSNorm without learnable scale (with_scale=False).""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = x.float() * torch.rsqrt( + x.float().pow(2).mean(-1, keepdim=True) + self.eps + ) + return output.type_as(x) + + +def get_rmsnorm_cls(): + """Single source of truth for the outer norm class (matches PR #106).""" + return Gemma4RMSNorm + + +# ==================================================================================== +# Embeddings + softcapped LM head (verbatim from PR #106) +# ==================================================================================== + + +class SoftcappedLMHead(nn.Module): + """Wrap lm_head and apply final_logit_softcapping: cap * tanh(x / cap).""" + + def __init__(self, linear: nn.Module, cap: float): + super().__init__() + self.linear = linear + 
self.cap = cap + + def forward(self, x): + logits = self.linear(x) + logits = logits.float() + return self.cap * torch.tanh(logits / self.cap) + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.linear, name) + + +class Gemma4ScaledEmbedding(nn.Module): + """Token embedding with sqrt(hidden_size) scaling (per gemma4 source).""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + dtype: torch.dtype, + shard_across_embedding: bool = True, + pad: bool = True, + sequence_parallel_enabled: bool = False, + ): + super().__init__() + self.embed_scale = embedding_dim**0.5 + self.embedding = ParallelEmbedding( + num_embeddings, + embedding_dim, + padding_idx, + dtype=dtype, + shard_across_embedding=shard_across_embedding, + pad=pad, + sequence_parallel_enabled=sequence_parallel_enabled, + ) + + def forward(self, input_ids: torch.Tensor): + return self.embedding(input_ids) * self.embed_scale + + +# ==================================================================================== +# Configuration (PR #106 pattern, extended for MoE attributes) +# ==================================================================================== + + +class Gemma4NeuronConfig(MoENeuronConfig): + """NeuronConfig hard-pinning the gemma4 attention class. + + Extends `MoENeuronConfig` (not plain `NeuronConfig`) so the MoE-specific + attributes that `initialize_moe_module` reads off `config.neuron_config` + -- `router_config`, `blockwise_matmul_config`, `moe_ep_degree`, + `moe_tp_degree`, `glu_mlp`, `glu_type`, `normalize_top_k_affinities`, + `early_expert_affinity_modulation`, `is_prefill_stage`, etc. -- exist + even on dense smoke-compile runs (default values are harmless when MoE + is disabled). 
+ """ + + def __init__(self, **kwargs): + # Gemma-4-26B-A4B keeps the routed top-k weights AS GIVEN by the router + # (the router itself already renormalizes to sum=1 and applies the + # per-expert scale). We do NOT want NxDI to renormalize again, so set + # the disable flag (NeuronGemma4Router, installed by + # NeuronGemma4MoEBlock, pre-builds the (T,E) expert_affinities tensor + # with zeros outside top-k). + kwargs.setdefault("disable_normalize_top_k_affinities", True) + super().__init__(**kwargs) + # attn_cls is set per-layer in get_updated_configs(); this default is + # used for the framework-level introspection only. + self.attn_cls = NeuronGemma4Attention + + +class Gemma4InferenceConfig(InferenceConfig): + """Inference config that pulls fields from HF gemma4 config.json (text_config).""" + + def __init__( + self, + neuron_config: NeuronConfig, + fused_spec_config=None, + load_config=None, + **kwargs, + ): + self.neuron_config = neuron_config + self.fused_spec_config = fused_spec_config + + if load_config is not None: + load_config(self) + else: + self.load_config() + + # Gemma4 nests text params under text_config.
+ text_config = getattr(self, "text_config", None) + if text_config is not None: + if isinstance(text_config, dict): + self.text_config = SimpleNamespace(**text_config) + text_config = self.text_config + text_attrs = [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "head_dim", + "intermediate_size", + "vocab_size", + "max_position_embeddings", + "rms_norm_eps", + "sliding_window", + "hidden_activation", + # Gemma4-specific + "global_head_dim", + "num_global_key_value_heads", + "attention_k_eq_v", + "final_logit_softcapping", + "layer_types", + "rope_parameters", + # MoE-specific (26B-A4B) + "enable_moe_block", + "num_experts", + "top_k_experts", + "moe_intermediate_size", + # PLE / KV-share (0 on 26B-A4B but kept for forward compat) + "hidden_size_per_layer_input", + "vocab_size_per_layer_input", + "num_kv_shared_layers", + "use_double_wide_mlp", + "tie_word_embeddings", + "pad_token_id", + ] + for attr in text_attrs: + if isinstance(text_config, dict): + if attr in text_config: + setattr(self, attr, text_config[attr]) + elif hasattr(text_config, attr): + setattr(self, attr, getattr(text_config, attr)) + + # PretrainedConfig defaults that SimpleNamespace conversion drops. 
+ text_config = getattr(self, "text_config", None) + if text_config is not None: + for attr, default in [ + ("output_attentions", False), + ("output_hidden_states", False), + ("use_return_dict", True), + ]: + if not hasattr(text_config, attr): + setattr(text_config, attr, default) + for attr, default in [ + ("output_attentions", False), + ("output_hidden_states", False), + ("use_return_dict", True), + ]: + if not hasattr(self, attr): + setattr(self, attr, default) + + if not hasattr(self, "pad_token_id"): + self.pad_token_id = 0 + if not hasattr(self, "tie_word_embeddings"): + self.tie_word_embeddings = True + if not hasattr(self, "attention_bias"): + self.attention_bias = False + + if hasattr(self, "hidden_activation") and not hasattr(self, "hidden_act"): + self.hidden_act = self.hidden_activation + + self.add_derived_config() + self.validate_config() + + def add_derived_config(self): + self.num_cores_per_group = 1 + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "head_dim", + "vocab_size", + "max_position_embeddings", + "rms_norm_eps", + "intermediate_size", + "global_head_dim", + "num_global_key_value_heads", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[Gemma4NeuronConfig]: + return Gemma4NeuronConfig + + +def get_updated_configs(config: Gemma4InferenceConfig): + """Per-layer configs for heterogeneous SWA/global layers (PR #106 pattern).""" + updated_configs = [] + for i in range(config.num_hidden_layers): + layer_config = copy.deepcopy(config) + layer_type = config.layer_types[i] + + # MoE config aliases that NxDI's `initialize_moe_module` reads off the + # layer config. We do NOT overwrite `intermediate_size` here (the + # dense MLP needs it); the MoE block uses `_moe_config` (set later) + # which has the moe_intermediate_size. 
+ if getattr(layer_config, "enable_moe_block", False): + layer_config.num_local_experts = layer_config.num_experts + layer_config.num_experts_per_tok = layer_config.top_k_experts + + if layer_type == "sliding_attention": + layer_config.sliding_window = config.sliding_window + layer_config._layer_head_dim = config.head_dim + layer_config._layer_num_kv_heads = config.num_key_value_heads + layer_config._layer_is_sliding = True + layer_config._layer_k_eq_v = False + rope_params = config.rope_parameters.get("sliding_attention", {}) + layer_config._layer_rope_theta = rope_params.get("rope_theta", 10000.0) + layer_config._layer_partial_rotary_factor = 1.0 + else: + layer_config.sliding_window = None + layer_config._layer_head_dim = config.global_head_dim + layer_config._layer_num_kv_heads = config.num_global_key_value_heads + layer_config._layer_is_sliding = False + layer_config._layer_k_eq_v = getattr(config, "attention_k_eq_v", False) + rope_params = config.rope_parameters.get("full_attention", {}) + layer_config._layer_rope_theta = rope_params.get("rope_theta", 1000000.0) + layer_config._layer_partial_rotary_factor = rope_params.get( + "partial_rotary_factor", 0.25 + ) + + updated_configs.append(layer_config) + return updated_configs + + +# ==================================================================================== +# Attention (PR #106 verbatim, head_dim values match for 26B-A4B) +# ==================================================================================== + + +class NeuronGemma4Attention(NeuronAttentionBase): + """Gemma4 attention with per-layer head_dim/kv_heads, partial RoPE, v_norm. + + Borrowed wholesale from PR #106 (Jim Burtoft, gemma-4-31B-IT). 26B-A4B + shares all the head dimensions (256 sliding / 512 global) so this works + unchanged. The only per-layer differences (kv head counts, rope theta) + are read from the per-layer config object produced by `get_updated_configs`.
+ """ + + def __init__(self, config: Gemma4InferenceConfig): + head_dim = config._layer_head_dim + num_kv_heads = config._layer_num_kv_heads + is_sliding = config._layer_is_sliding + rope_theta = config._layer_rope_theta + partial_rotary_factor = config._layer_partial_rotary_factor + + # Partial RoPE: rotate first head_dim*factor dims, leave the rest alone. + rotary_dim = int(head_dim * partial_rotary_factor) + rotary_dim = rotary_dim - (rotary_dim % 2) + + rotary_emb = RotaryEmbedding( + dim=rotary_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + ) + + # PR #106 Discovery #27: pass sliding_window=None to base for ALL layers + # to avoid OOB in get_last_kv_window when bucket_size < sliding_window. + # Windowed masking is applied at the decoder-layer level via local_mask. + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=num_kv_heads, + head_dim=head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=config.rms_norm_eps, + use_qk_norm=False, + sliding_window=None, + post_transpose_layernorm=True, + ) + + # QK norms: gemma4 RMSNorm with learned weight (initialized to 1). + self.q_layernorm = get_rmsnorm_cls()(dim=head_dim, eps=config.rms_norm_eps) + self.k_layernorm = get_rmsnorm_cls()(dim=head_dim, eps=config.rms_norm_eps) + + # V norm: RMSNorm without learnable scale. 
+ self.v_norm = Gemma4VNorm(dim=head_dim, eps=config.rms_norm_eps) + + self._is_sliding = is_sliding + self._k_eq_v = config._layer_k_eq_v + self._head_dim = head_dim + self._rotary_dim = rotary_dim + self._partial_rotary_factor = partial_rotary_factor + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + if self.rotary_emb is None: + return Q, K, cos_cache, sin_cache + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + if self._rotary_dim == self._head_dim: + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + else: + q_rot = Q[..., : self._rotary_dim] + q_pass = Q[..., self._rotary_dim :] + k_rot = K[..., : self._rotary_dim] + k_pass = K[..., self._rotary_dim :] + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos_cache, sin_cache) + Q = torch.cat([q_rot, q_pass], dim=-1) + K = torch.cat([k_rot, k_pass], dim=-1) + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask): + """Use NKI d=256 SWA kernel for sliding layers when available.""" + if ( + _HAS_NKI_SWA_KERNEL + and self._is_sliding + and self._head_dim == 256 + and q_len >= 128 + ): + q_kernel = Q.to(self.torch_dtype) + k_kernel = K.to(self.torch_dtype) + v_kernel = V.to(self.torch_dtype) + + n_kv_heads = K.shape[1] + n_q_heads = Q.shape[1] + q_h_per_kv = n_q_heads // n_kv_heads + window_size = 1024 + + out_parts = [] + for b in range(bsz): + for kv_h in range(n_kv_heads): + q_slice = q_kernel[ + b : b + 1, kv_h * q_h_per_kv : (kv_h + 1) * q_h_per_kv, :, : + ] + k_slice = k_kernel[b : b + 1, kv_h : kv_h + 1, :, :] + v_slice = v_kernel[b : b + 1, kv_h : kv_h + 1, :, :] + o_part = _nki_flash_attn_d256_swa( + q_slice, + k_slice, + v_slice, + q_h_per_k_h=q_h_per_kv, + n_kv_heads=1, + seqlen_q=q_len, + seqlen_kv=q_len, + window_size=window_size, + ) + out_parts.append(o_part) + attn_output = torch.cat(out_parts, dim=1) + if bsz > 1: + 
attn_output = attn_output.reshape(bsz, n_q_heads, q_len, self._head_dim) + return attn_output, FlashAttentionStrategy.NONE + + return super().perform_prefill(Q, K, V, q_len, bsz, attention_mask) + + def prep_qkv_tensors( + self, + position_ids, + hidden_states, + past_key_value=None, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + skip_rope=False, + residual=None, + use_polar_compatible_rope=False, + ): + Q, K, V, cos_cache, sin_cache, residual = super().prep_qkv_tensors( + position_ids=position_ids, + hidden_states=hidden_states, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + skip_rope=skip_rope, + residual=residual, + use_polar_compatible_rope=use_polar_compatible_rope, + ) + # Apply v_norm on BHSD-laid V (last dim == head_dim). + V = self.v_norm(V) + return Q, K, V, cos_cache, sin_cache, residual + + +# ==================================================================================== +# MLP (dense feed-forward) +# ==================================================================================== + + +class NeuronGemma4MLP(nn.Module): + """Dense GeGLU MLP (gated MLP with `gelu_pytorch_tanh` activation).""" + + def __init__(self, config: Gemma4InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + dtype = config.neuron_config.torch_dtype + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + pad=True, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + pad=True, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + ) + self.act_fn = nn.GELU(approximate="tanh") + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor,
None]: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)), None + + +# ==================================================================================== +# Router + MoE block (26B-A4B specific — PR #106 has no MoE) +# ==================================================================================== + + +class NeuronGemma4Router(nn.Module): + """Top-K router with gemma4's `scale` and `per_expert_scale` learned tensors. + + Mirrors HF `Gemma4TextRouter` (modeling_gemma4.py:1334) and exposes the + contract that NxDI's `MoE` wrapper expects from its `router` member: + + forward(hidden_states) -> (router_logits, expert_affinities, expert_index) + + where + router_logits : (T, E) raw projection (used for aux losses; we + return it for compatibility, never used + at inference) + expert_affinities : (T, E) sparse tensor with the FINAL post-softmax + / post-renormalize / post-per-expert-scale + weights at top-k indices, zero elsewhere. + With `normalize_top_k_affinities=False` + on the MoE config, NxDI's expert dispatch + uses these values directly. + expert_index : (T, K) top-k expert indices. + + Routing math is FP32 for numerical stability across 128 experts. + + Replicated across TP (no parallel layer): every rank computes the same + routing decisions, so no all-reduce is needed for the indices. + """ + + REQUIRED_ATTRS = ("num_experts", "top_k", "hidden_size", "sequence_parallel_enabled", "sequence_dimension") + + def __init__(self, config: Gemma4InferenceConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.scalar_root_size = self.hidden_size ** -0.5 + self.eps = config.rms_norm_eps + self.top_k = config.top_k_experts + self.num_experts = config.num_experts + + # NxDI's `MoE.__init__` cross-checks (`router.num_experts`, + # `router.top_k`, `router.hidden_size`) against the expert_mlps + # config -- the names above already match. 
+ + # NxDI also reads `router.sequence_parallel_enabled` and + # `router.sequence_dimension` to decide whether to gather + # hidden_states before / after routing. We replicate the router + # weights across TP, so SP-on-router is False. + self.sequence_parallel_enabled = False + self.sequence_dimension = 1 + + # No-scale RMSNorm (gemma4 source line 1342: with_scale=False). + self.norm = Gemma4VNorm(self.hidden_size, eps=self.eps) + # Replicated across TP — every rank must reach the same routing. + self.proj = nn.Linear( + self.hidden_size, self.num_experts, bias=False, dtype=torch.float32 + ) + # Learned scale parameters (gemma4 source 1344-1345). + self.scale = nn.Parameter(torch.ones(self.hidden_size, dtype=torch.float32)) + self.per_expert_scale = nn.Parameter( + torch.ones(self.num_experts, dtype=torch.float32) + ) + + def forward(self, hidden_states: torch.Tensor): + # Accept either (B, S, H) or (T, H). NxDI's MoE wrapper passes the + # full / SP-gathered hidden states; we flatten to (T, H) for routing. + original_shape = hidden_states.shape + h = hidden_states.float().reshape(-1, original_shape[-1]) # (T, H) + + h = self.norm(h) + h = h * self.scale * self.scalar_root_size + + router_logits = self.proj(h) # (T, E) FP32 + probs = F.softmax(router_logits, dim=-1) + + top_k_weights, top_k_index = torch.topk(probs, k=self.top_k, dim=-1) + # Per-token re-normalisation (gemma4 source line 1362). + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + # Apply per-expert scale (gemma4 source line 1365). + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + + # Build sparse (T, E) expert_affinities: zero everywhere except the + # top-k positions, which carry the post-renormalize, post-per-expert + # scaled weights. With MoE config `normalize_top_k_affinities=False`, + # the downstream ExpertMLPsV2 will use these values verbatim. 
+ expert_affinities = torch.zeros_like(probs) + expert_affinities = expert_affinities.scatter( + 1, top_k_index, top_k_weights.to(expert_affinities.dtype) + ) + + # Cast to hidden_states dtype so downstream matmuls stay in bf16. + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = top_k_index.detach().to(dtype=torch.long) + # router_logits stays in FP32; NxDI doesn't actually use it at + # inference (return_router_logits is False by default). + return router_logits, expert_affinities, expert_index + + +class NeuronGemma4MoEBlock(nn.Module): + """Gemma-4 MoE block built on top of NxDI's `initialize_moe_module`. + + Why we don't just call NxDI's wrapper `MoE.forward(hidden_states)`: + `MoE.forward` calls the router and the expert_mlps with the SAME + `hidden_states`, but Gemma-4's HF source feeds the router the RAW + pre-MLP residual (line 1434) while the experts get the + `pre_feedforward_layernorm_2(residual)` (line 1435-1436). The two + inputs differ by an RMSNorm with a learned scale, so collapsing them + is numerically incorrect. + + We therefore: + 1. Use `initialize_moe_module(config=...)` to construct router + + expert_mlps with all the correct process-group / dtype wiring. + 2. Replace the auto-built `RouterTopK` with `NeuronGemma4Router` + (custom math: no-scale norm, scalar_root_size, learned scale + + per_expert_scale, top-k renorm). + 3. Re-implement the wrapper's forward path — call router on + `router_input`, expert_mlps on `expert_input`, then do the + delayed all-reduce that `MoE.forward` would have done. + + No shared experts inside this block: Gemma-4's "shared" branch is the + dense MLP that runs in PARALLEL with this block at the decoder-layer + level (HF source 1427+1441 — `mlp(x) + experts(x)`), so we set + `n_shared_experts = 0` and skip NxDI's `_apply_shared_experts` path. + + Sequence parallelism / token shuffle / EP > 1 are unsupported by this + wrapper (we follow the simplest non-SP, non-EP path). 
If those are + needed later, mirror the corresponding branches from + `neuronx_distributed.modules.moe.model.MoE._forward_compute_bound`. + """ + + def __init__(self, config: Gemma4InferenceConfig): + super().__init__() + self.config = config + # `initialize_moe_module` reads `n_shared_experts` -- Llama4 sets =1, + # we set =0 because Gemma's "shared" branch is the dense MLP that + # lives OUTSIDE this module. + if not hasattr(config, "n_shared_experts"): + config.n_shared_experts = 0 + + # Build the underlying NxDI MoE wrapper to leverage its router + + # expert_mlps construction (process groups, dtype, padding, etc.). + self.moe = initialize_moe_module(config=config) + # Replace the wrapper's RouterTopK with Gemma's custom router. Safe: + # `MoE.__init__` already validated num_experts/top_k/hidden_size + # cross-consistency, and we keep those identical on our router. + self.moe.router = NeuronGemma4Router(config) + + # Pull tensor-parallel group from the wrapper for the explicit + # delayed all-reduce in our forward. + self._tp_group = self.moe.tensor_parallel_group + self._ep_enabled = self.moe.ep_enabled + + def forward( + self, + router_input: torch.Tensor, + expert_input: torch.Tensor, + ) -> torch.Tensor: + """Compute Gemma-4 MoE output. + + Arguments: + router_input: pre-norm `residual`, shape (B, S, H) or (T, H). + The router applies its own (no-scale) RMSNorm + internally, matching HF source line 1434. + expert_input: `pre_feedforward_layernorm_2(residual)`, same + leading layout as `router_input`. Fed to the + expert MLPs (HF source line 1436). + + Returns: + output: same leading shape as `expert_input`, after expert + dispatch + post-TP all-reduce. Caller is responsible + for any further normalization (HF's + `post_feedforward_layernorm_2`) and summation with + the dense-MLP branch. + """ + # Router consumes raw residual (it flattens to (T, H) internally). + # Returns: router_logits (T,E), expert_affinities (T,E), + # expert_index (T, top_k). 
+ _router_logits, expert_affinities, expert_index = self.moe.router(router_input) + + # Mirror MoE.forward: gradient bookkeeping all-reduce on expert + # affinities (no-op forward in pure inference, but keeps the + # autograd graph consistent if we're ever traced for backward). + if not self._ep_enabled: + expert_affinities = _nxd_parallel_mappings.copy_to_tensor_model_parallel_region( + expert_affinities, process_group=self._tp_group + ) + + # Flatten experts input to (T, H). seq_len comes from the + # sequence_dimension axis of the un-flattened input. Caller is + # expected to pass (B, S, H); if a 2D (T, H) tensor lands here we + # fall back to seq_len = T (seq_len affects only path selection + # heuristics inside ExpertMLPsV2, not correctness for full-capacity + # routing). + expert_shape = expert_input.shape + if expert_input.dim() >= self.moe.sequence_dimension + 1: + seq_len = expert_shape[self.moe.sequence_dimension] + else: + seq_len = expert_shape[0] + expert_input_flat = expert_input.reshape(-1, expert_shape[-1]) + + output = self.moe.expert_mlps( + hidden_states=expert_input_flat, + expert_affinities=expert_affinities, + expert_index=expert_index, + seq_len=seq_len, + ) + + # Reshape back and do delayed all-reduce (TP) since expert_mlps + # leaves the output in tensor-parallel un-reduced form. 
+ output = output.view(expert_shape) + if not self._ep_enabled: + output = _nxd_parallel_mappings.reduce_from_tensor_model_parallel_region( + output, process_group=self._tp_group + ) + return output + + +# ==================================================================================== +# Decoder layer (dense MLP + parallel MoE branch) +# ==================================================================================== + + +class NeuronGemma4DecoderLayer(nn.Module): + """Gemma4 decoder layer with optional parallel MoE branch and layer_scalar.""" + + def __init__(self, config: Gemma4InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.is_sliding_window_attention = config._layer_is_sliding + + self.self_attn = NeuronGemma4Attention(config) + self.mlp = NeuronGemma4MLP(config) + + norm_cls = get_rmsnorm_cls() + self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + + # Per-layer learned scaling factor (must be Parameter, not buffer, so + # NxDI's weight loader populates it from the checkpoint). + self.layer_scalar = nn.Parameter(torch.ones(1), requires_grad=False) + + # MoE branch can be disabled for smoke compile (set + # `disable_moe_for_smoke_compile=True` on the InferenceConfig) to + # validate the rest of the architecture without the routed experts. + # When enabled, the dense MLP and the MoE block run in PARALLEL and + # their outputs are summed (HF source 1427-1441). + self.enable_moe_block = bool(getattr(config, "enable_moe_block", False)) and not bool( + getattr(config, "disable_moe_for_smoke_compile", False) + ) + if self.enable_moe_block: + # Build a separate config view for MoE. 
Three things must change + # vs the per-layer attention config: + # 1. `intermediate_size` -> `moe_intermediate_size` (704 not + # 2112). NxDI's `ExpertMLPsV2` reads `config.intermediate_size` + # for the expert intermediate dim. + # 2. `num_local_experts` / `num_experts_per_tok` aliases for + # Gemma's `num_experts` / `top_k_experts` (already set at + # the per-layer level by `get_updated_configs`, but we set + # again for explicitness on the deepcopy). + # 3. `n_shared_experts = 0` -- the dense MLP plays the role of + # "shared expert" but lives OUTSIDE the MoE block (parallel + # branch summed with experts output, per HF source). + moe_config = copy.deepcopy(config) + moe_config.intermediate_size = config.moe_intermediate_size + moe_config.num_local_experts = config.num_experts + moe_config.num_experts_per_tok = config.top_k_experts + moe_config.n_shared_experts = 0 + # ExpertMLPsV2 raises `Unknown activation: gelu_pytorch_tanh`. + # Alias to plain `gelu` -- tanh-approximation diff is <1e-4 and not + # load-bearing for routing/topk correctness. Flag for accuracy + # validation pass. + hf_act = getattr(config, "hidden_activation", "gelu_pytorch_tanh") + moe_config.hidden_act = "gelu" if hf_act == "gelu_pytorch_tanh" else hf_act + + self.moe_block = NeuronGemma4MoEBlock(moe_config) + self.post_feedforward_layernorm_1 = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm_2 = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm_2 = norm_cls(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + # Heterogeneous RoPE: drop cached cos/sin from the previous layer.
+ kwargs.pop("cos_cache", None) + kwargs.pop("sin_cache", None) + + # SWA layers use local_mask; global layers use attention_mask. + local_mask = kwargs.pop("local_mask", None) + mask = ( + local_mask + if (self.is_sliding_window_attention and local_mask is not None) + else attention_mask + ) + + # Attention block + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Feed-forward block + residual = hidden_states + hidden_states_pre = self.pre_feedforward_layernorm(hidden_states) + hidden_states_dense = self.mlp(hidden_states_pre)[0] + + if self.enable_moe_block: + hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states_dense) + + # HF source 1432-1438: router consumes the RAW residual, while + # the experts consume `pre_feedforward_layernorm_2(residual)`. + # The two inputs differ by a learned-scale RMSNorm and cannot + # be collapsed, so we pass them separately to the MoE block + # (which bypasses NxDI's MoE wrapper to honor this split). 
+ expert_input = self.pre_feedforward_layernorm_2(residual) + hidden_states_2 = self.moe_block( + router_input=residual, + expert_input=expert_input, + ) + hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) + + hidden_states = hidden_states_1 + hidden_states_2 + else: + hidden_states = hidden_states_dense + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Per-layer scalar + hidden_states = hidden_states * self.layer_scalar + + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# ==================================================================================== +# KV cache manager (PR #106 verbatim — handles heterogeneous SWA/global shapes) +# ==================================================================================== + + +class Gemma4KVCacheManager(KVCacheManager): + """KV cache manager with per-layer heterogeneous shapes. + + 26B-A4B layer kv configs (per rank, after TP sharding): + - SWA layers: num_kv_heads=8 / TP, head_dim=256 + - Global layers: num_kv_heads=2 / TP, head_dim=512 + """ + + def __init__( + self, + config, + layer_kv_configs, + global_rank=None, + attention_chunk_size=None, + sliding_window=None, + windowed_context_encoding_size=None, + layer_to_cache_size_mapping=None, + ): + self._layer_kv_configs = layer_kv_configs + + if layer_to_cache_size_mapping is None: + max_len = config.neuron_config.max_length + layer_to_cache_size_mapping = [max_len] * len(layer_kv_configs) + + max_kv_heads = max(c[0] for c in layer_kv_configs) + super().__init__( + config, + num_kv_head=max_kv_heads, + global_rank=global_rank, + attention_chunk_size=attention_chunk_size, + sliding_window=sliding_window, + windowed_context_encoding_size=windowed_context_encoding_size, + layer_to_cache_size_mapping=layer_to_cache_size_mapping, + ) + + def _init_kv_shape(self, config, layer_to_cache_size_mapping=None): + max_batch_size = ( + 
config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + max_len = config.neuron_config.max_length + + if ( + self.attention_chunk_size + and self.attention_chunk_size < max_len + and not layer_to_cache_size_mapping + ): + max_len = self.attention_chunk_size + elif self.sliding_window: + max_len = self.sliding_window + + if layer_to_cache_size_mapping: + layer_seq_lens = list(layer_to_cache_size_mapping) + else: + layer_seq_lens = [max_len] * len(self._layer_kv_configs) + + self.k_shapes = [] + self.v_shapes = [] + self.padded_layer_ids = [] + for idx, (kv_heads_per_rank, head_dim) in enumerate(self._layer_kv_configs): + cache_len = layer_seq_lens[idx] + k_shape, v_shape = get_kv_shapes( + cache_len, + max_batch_size, + kv_heads_per_rank, + head_dim, + self.k_cache_transposed, + self.is_kv_cache_tiled, + ) + self.k_shapes.append(k_shape) + self.v_shapes.append(v_shape) + + max_kv_heads = max(c[0] for c in self._layer_kv_configs) + max_head_dim = max(c[1] for c in self._layer_kv_configs) + self.k_shape, self.v_shape = get_kv_shapes( + max_len, + max_batch_size, + max_kv_heads, + max_head_dim, + self.k_cache_transposed, + self.is_kv_cache_tiled, + ) + + +# ==================================================================================== +# Top-level model +# ==================================================================================== + + +class NeuronGemma4TextModel(NeuronBaseModel): + """Gemma4 text decoder: scaled embeds + decoder layers + final norm + softcapped lm_head.""" + + def setup_attr_for_model(self, config: Gemma4InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + # Use the maximum KV head count (SWA = 8 on 26B-A4B) for base class. 
+ self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Gemma4InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = Gemma4ScaledEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + ) + + updated_configs = get_updated_configs(config) + self.layers = nn.ModuleList( + [NeuronGemma4DecoderLayer(conf, idx) for idx, conf in enumerate(updated_configs)] + ) + + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + lm_head_linear = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + pad=True, + gather_output=not self.on_device_sampling, + dtype=config.neuron_config.torch_dtype, + ) + + self.final_logit_softcapping = getattr(config, "final_logit_softcapping", None) + if ( + self.final_logit_softcapping is not None + and self.final_logit_softcapping > 0 + ): + self.lm_head = SoftcappedLMHead(lm_head_linear, self.final_logit_softcapping) + else: + self.lm_head = lm_head_linear + + self.has_mixed_attn = True + self.sliding_window = config.sliding_window + + max_length = config.neuron_config.max_length + sw = config.sliding_window or max_length + self._uniform_cache_len = max(sw, max_length) + self.layer_to_cache_size_mapping = [self._uniform_cache_len] * config.num_hidden_layers + + def _create_windowed_attn_mask_tkg(self, attention_mask, window_size, position_ids): + """SWA TKG mask must match uniform KV cache size (PR #106 fix).""" + batch_size, _ = attention_mask.shape + cache_len = self._uniform_cache_len + + if cache_len == window_size: + return super()._create_windowed_attn_mask_tkg( + attention_mask, window_size, position_ids + ) + + 
pos = position_ids[:, 0] + idx = torch.arange(window_size, device=attention_mask.device).unsqueeze(0) + base_mask = (idx < pos.unsqueeze(1)) & (idx < window_size - 1) + + full_mask = torch.ones( + (batch_size, window_size), dtype=torch.bool, device=attention_mask.device + ) + full_mask[:, -1] = False + seq_less_than_window = pos < window_size - 1 + window_mask = torch.where( + seq_less_than_window.unsqueeze(1), base_mask, full_mask + ) + pad_len = cache_len - window_size + padded_mask = F.pad(window_mask, (0, pad_len), value=False) + return padded_mask[:, None, None, :] + + def _create_simple_attn_mask(self, attention_mask): + """Global mask must match uniform KV cache size (PR #106 fix).""" + batch_size = attention_mask.shape[0] + pad_len = self._uniform_cache_len - self.n_positions + if pad_len > 0: + attention_mask = F.pad(attention_mask, (0, pad_len), value=0) + return ( + attention_mask[:, None, None, :] + .expand(batch_size, 1, 1, self._uniform_cache_len) + .to(torch.bool) + ) + + def init_inference_optimization(self, config: Gemma4InferenceConfig): + if self.on_device_sampling: + try: + from neuronx_distributed_inference.modules.generation.sampling import ( + create_sampler, + ) + except ImportError: + from neuronx_distributed_inference.modules.sampling.utils import ( + create_sampler, + ) + + lm_head_tp_degree = None + if hasattr(self, "lm_head") and hasattr( + self.lm_head, "tensor_parallel_group" + ): + lm_head_tp_degree = self.lm_head.tensor_parallel_group.size() + self.sampler = create_sampler(config.neuron_config, lm_head_tp_degree) + + tp_degree = config.neuron_config.tp_degree + layer_kv_configs = [] + for i in range(config.num_hidden_layers): + layer_type = config.layer_types[i] + if layer_type == "sliding_attention": + kv_heads = config.num_key_value_heads + hd = config.head_dim + else: + kv_heads = config.num_global_key_value_heads + hd = config.global_head_dim + gqa_strategy = determine_sharding_strategy(tp_degree, kv_heads) + _, 
shardable_kv_heads = get_shardable_head_counts( + tp_degree, config.num_attention_heads, kv_heads, gqa_strategy + ) + kv_heads_per_rank = max(1, shardable_kv_heads // tp_degree) + layer_kv_configs.append((kv_heads_per_rank, hd)) + + self._layer_kv_configs = layer_kv_configs + self._max_kv_heads_per_rank = max(c[0] for c in layer_kv_configs) + self._max_head_dim = max(c[1] for c in layer_kv_configs) + + self.kv_mgr = Gemma4KVCacheManager( + config, + layer_kv_configs=layer_kv_configs, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + + +# ==================================================================================== +# Causal-LM wrapper + state-dict converter +# ==================================================================================== + + +class NeuronGemma4ForCausalLM(NeuronBaseForCausalLM): + """Gemma4 causal LM for NeuronX inference. + + Handles weight conversion from HF checkpoint to NxDI naming, including: + * stripping `language_model.`, `model.` prefixes + * embed_tokens -> embed_tokens.embedding (ScaledEmbedding wrapper) + * q/k_norm -> q_layernorm/k_layernorm + * QK scaling correction (cancel NxDI's automatic 1/sqrt(head_dim)) + * `attention_k_eq_v` -> copy k_proj weights into v_proj for global layers + * tied lm_head (handle SoftcappedLMHead path) + * rank_util tensors for TP + + Plus the 26B-A4B-specific MoE keys: + * `layers.{i}.router.{proj.weight,scale,per_expert_scale}` -> + `layers.{i}.moe_block.moe.router.{...}` (router lives inside our + wrapper module after `initialize_moe_module`). + * `layers.{i}.experts.gate_up_proj` (HF shape `(E, 2*I, H)`) -> + `layers.{i}.moe_block.moe.expert_mlps.mlp_op.gate_up_proj.weight` + (NxDI shape `(E, H, 2*I)` — last two dims transposed). 
+ * `layers.{i}.experts.down_proj` (HF shape `(E, H, I)`) -> + `layers.{i}.moe_block.moe.expert_mlps.mlp_op.down_proj.weight` + (NxDI shape `(E, I, H)` — last two dims transposed). + """ + + _model_cls = NeuronGemma4TextModel + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import Gemma4ForConditionalGeneration # type: ignore[import-not-found] + + return Gemma4ForConditionalGeneration.from_pretrained(model_path, **kwargs) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, torch.Tensor], + config: Gemma4InferenceConfig, + ) -> Dict[str, torch.Tensor]: + neuron_config = config.neuron_config + tp_degree = neuron_config.tp_degree + new_state_dict = {} + + for key, weight in state_dict.items(): + new_key = key + + if new_key.startswith("language_model.model."): + new_key = new_key[len("language_model.model."):] + elif new_key.startswith("language_model."): + new_key = new_key[len("language_model."):] + elif new_key.startswith("model.language_model.model."): + new_key = new_key[len("model.language_model.model."):] + elif new_key.startswith("model.language_model."): + new_key = new_key[len("model.language_model."):] + elif new_key.startswith("model."): + new_key = new_key[len("model."):] + + # Skip vision/audio/multimodal weights — text-only port for now. + if ( + "vision_tower." in new_key + or "multi_modal_projector." in new_key + or "embed_vision." in new_key + or "audio_tower." in new_key + or "embed_audio." 
in new_key + ): + continue + + if new_key == "embed_tokens.weight": + new_key = "embed_tokens.embedding.weight" + + new_key = new_key.replace(".self_attn.q_norm.", ".self_attn.q_layernorm.") + new_key = new_key.replace(".self_attn.k_norm.", ".self_attn.k_layernorm.") + + new_state_dict[new_key] = weight.detach().clone() + + # Per-layer transformations + for i in range(config.num_hidden_layers): + layer_type = config.layer_types[i] + is_global = layer_type == "full_attention" + + if is_global: + hd = config.global_head_dim + else: + hd = config.head_dim + + prefix = f"layers.{i}.self_attn" + + # QK scaling: gemma4 uses scaling=1.0 (no 1/sqrt(head_dim)). NxDI + # always applies 1/sqrt(head_dim). Pre-scale q_layernorm.weight by + # sqrt(head_dim): the norm weight is a linear gain on Q, so the two cancel. + q_norm_key = f"{prefix}.q_layernorm.weight" + if q_norm_key in new_state_dict: + scaling_factor = math.sqrt(float(hd)) + orig_dtype = new_state_dict[q_norm_key].dtype + new_state_dict[q_norm_key] = ( + new_state_dict[q_norm_key].to(torch.float32) * scaling_factor + ).to(orig_dtype) + + # attention_k_eq_v: copy K weights to V for global layers (no v_proj in HF). + if is_global and getattr(config, "attention_k_eq_v", False): + k_key = f"{prefix}.k_proj.weight" + v_key = f"{prefix}.v_proj.weight" + if k_key in new_state_dict and v_key not in new_state_dict: + new_state_dict[v_key] = new_state_dict[k_key].detach().clone() + + new_state_dict[f"{prefix}.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + # MoE remap: HF stores router/experts directly under the + # decoder layer. Our `NeuronGemma4DecoderLayer` nests them as + # `moe_block.moe.{router,expert_mlps.mlp_op}`. Skip silently + # if the layer has no MoE keys (e.g. dense smoke-compile run + # or a future variant that toggles MoE per-layer). + moe_src_prefix = f"layers.{i}." + moe_dst_prefix = f"layers.{i}.moe_block.moe." 
+ router_keys = [ + ("router.proj.weight", "router.proj.weight"), + ("router.scale", "router.scale"), + ("router.per_expert_scale", "router.per_expert_scale"), + ] + for src_suffix, dst_suffix in router_keys: + src_key = moe_src_prefix + src_suffix + if src_key in new_state_dict: + new_state_dict[moe_dst_prefix + dst_suffix] = new_state_dict.pop( + src_key + ) + + gate_up_src = moe_src_prefix + "experts.gate_up_proj" + gate_up_dst = moe_dst_prefix + "expert_mlps.mlp_op.gate_up_proj.weight" + if gate_up_src in new_state_dict: + # HF shape: (E, 2*I, H); NxDI shape: (E, H, 2*I). + w = new_state_dict.pop(gate_up_src) + new_state_dict[gate_up_dst] = w.transpose(1, 2).contiguous() + + down_src = moe_src_prefix + "experts.down_proj" + down_dst = moe_dst_prefix + "expert_mlps.mlp_op.down_proj.weight" + if down_src in new_state_dict: + # HF shape: (E, H, I); NxDI shape: (E, I, H). + w = new_state_dict.pop(down_src) + new_state_dict[down_dst] = w.transpose(1, 2).contiguous() + + if neuron_config.vocab_parallel: + new_state_dict["embed_tokens.embedding.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + new_state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return new_state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + """Tied weights: embed_tokens -> lm_head (handle SoftcappedLMHead path).""" + embed_key = None + if "embed_tokens.embedding.weight" in state_dict: + embed_key = "embed_tokens.embedding.weight" + elif "embed_tokens.weight" in state_dict: + embed_key = "embed_tokens.weight" + + if embed_key is not None: + weight = state_dict[embed_key].clone() + state_dict["lm_head.weight"] = weight + state_dict["lm_head.linear.weight"] = weight.clone() + + @classmethod + def get_config_cls(cls): + return Gemma4InferenceConfig + + +__all__ = [ + "Gemma4InferenceConfig", + "Gemma4NeuronConfig", + "Gemma4KVCacheManager", + "Gemma4RMSNorm", + "Gemma4ScaledEmbedding", + "Gemma4VNorm", + 
"NeuronGemma4Attention", + "NeuronGemma4DecoderLayer", + "NeuronGemma4ForCausalLM", + "NeuronGemma4MLP", + "NeuronGemma4MoEBlock", + "NeuronGemma4Router", + "NeuronGemma4TextModel", + "SoftcappedLMHead", + "get_updated_configs", +] diff --git a/contrib/models/gemma-4-26b-a4b-it/src/ndxi_patch.py b/contrib/models/gemma-4-26b-a4b-it/src/ndxi_patch.py new file mode 100644 index 00000000..a1d43746 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/ndxi_patch.py @@ -0,0 +1,505 @@ +from typing import Callable, List, Optional, Tuple, Union +import math +import logging + +from neuronx_distributed_inference.utils.tensor_replacement.registry import ( + TensorReplacementRegister, +) +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + +logger = logging.getLogger(__name__) + + +def patched_get_last_kv_window( + window_size, + position_ids, + latest_k, + latest_v, + windowed_context_encoding_window_idx=-1, + spec_len=0, +): + """ + Replaces https://github.com/aws-neuron/neuronx-distributed-inference/blob/main/src/neuronx_distributed_inference/modules/attention/utils.py#L634 + to convert the index tensor in torch.gather to a LongTensor. Otherwise, the function will error out. 
+ """ + batch_size, num_head, _, head_dim = latest_k.shape + latest_pos = torch.amax(position_ids, dim=1) + if ( + windowed_context_encoding_window_idx >= 1 + ): # if windowed cte, account for current window offset + latest_pos -= windowed_context_encoding_window_idx * window_size + + # True window size + window_size = window_size - 1 + spec_len - 1 if spec_len > 0 else window_size - 1 + + end_idx = (latest_pos + 1).clamp(min=window_size) + start_idx = (end_idx - window_size).clamp(min=0) + orig_indices = start_idx[:, None] + torch.arange(window_size) + + # Calculate per-batch left shifts + left_shifts = (window_size - (end_idx % window_size)) % window_size + base = torch.arange(window_size).expand(batch_size, window_size) + shifted_idx = (base + left_shifts[:, None]) % window_size + + # Determine per-batch shifted gather indices + gather_idx = torch.gather(orig_indices, dim=1, index=shifted_idx.long()) + gather_idx = ( + gather_idx[:, None, :, None] + .expand(batch_size, num_head, window_size, head_dim) + .to(device=latest_k.device) + ) + + # Gather to create non-physically contiguous KV cache + latest_k = torch.gather(latest_k, dim=2, index=gather_idx.long()) + latest_v = torch.gather(latest_v, dim=2, index=gather_idx.long()) + return latest_k, latest_v + + +def patched_base_image_to_text_model_forward( + self, + input_ids: torch.LongTensor = None, + seq_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + prev_hidden: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + adapter_ids: Optional[torch.LongTensor] = None, + medusa_args=None, + return_dict: 
Optional[bool] = None, + llava_args: Optional[List] = [], + input_capture_hook: Optional[Callable] = None, + slot_mapping: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + full_context_lens: Optional[torch.LongTensor] = None, + computed_context_lens: Optional[torch.LongTensor] = None, + vision_embeddings: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.BoolTensor] = None, + tensor_capture_hook: Optional[ + Callable + ] = None, # Missing argument that triggers a NameError +) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + # infer attention_mask from position_ids if not provided + if attention_mask is None: + attention_mask = self._infer_attention_mask(position_ids) + + if seq_ids is None: + seq_ids = torch.arange(input_ids.shape[0]) + + input_ids, attention_mask, position_ids, seq_ids, sampling_params = ( + self.preprocess_inputs( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + sampling_params=sampling_params, + prev_hidden=prev_hidden, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + adapter_ids=adapter_ids, + medusa_args=medusa_args, + return_dict=return_dict, + llava_args=llava_args, + input_capture_hook=input_capture_hook, + slot_mapping=slot_mapping, + block_table=block_table, + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + ) + ) + + # Bypass _get_model_outputs entirely. 
NxDI 0.8.0 added a + # deepstack_vision_embeds arg that gets forwarded to the CTE/TKG + # models, but the ImageToTextModelWrapper.input_generator only + # traces 24 inputs (no deepstack). Calling the models directly + # with exactly 24 positional args avoids the mismatch. + _empty = torch.empty(0) + + if self._is_prefill(position_ids): + # Prefill: vision tensors must match the traced shapes even for + # text-only inputs. The CTE NEFF was traced with + # vision_embeddings=[batch, seq_len, hidden_size] + # vision_mask=[batch, seq_len, 1] + # so we create zero-filled tensors when they are not provided. + batch_size = input_ids.shape[0] + n_active = input_ids.shape[1] # == bucket seq_len + if vision_embeddings is None or vision_embeddings.numel() == 0: + neuron_config = getattr(self.config, "neuron_config", None) + dtype = neuron_config.torch_dtype if neuron_config is not None else torch.bfloat16 + vision_embeddings = torch.zeros( + batch_size, n_active, self.config.hidden_size, dtype=dtype + ) + if vision_mask is None or vision_mask.numel() == 0: + vision_mask = torch.zeros(batch_size, n_active, 1, dtype=torch.int32) + + outputs = self.context_encoding_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + _empty, # prev_hidden + _empty, # adapter_ids + _empty, # accepted_indices + _empty, # current_length + _empty, # medusa_mask + _empty, # scatter_index + _empty, # slot_mapping + _empty, # active_block_table + _empty, # num_queries + _empty, # computed_context_lens + _empty, # tile_q_indices + _empty, # tile_block_tables + _empty, # tile_masks + _empty, # inputs_embeds + _empty, # kv_cache + _empty, # active_mask + _empty, # rotary_position_ids + vision_embeddings, + vision_mask, + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + # Token generation: vision tensors must be empty (traced as [0]). 
+ outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + _empty, # prev_hidden + _empty, # adapter_ids + _empty, # accepted_indices + _empty, # current_length + _empty, # medusa_mask + _empty, # scatter_index + _empty, # slot_mapping + _empty, # active_block_table + _empty, # num_queries + _empty, # computed_context_lens + _empty, # tile_q_indices + _empty, # tile_block_tables + _empty, # tile_masks + _empty, # inputs_embeds + _empty, # kv_cache + _empty, # active_mask + _empty, # rotary_position_ids + _empty, # vision_embeddings (empty for TKG) + _empty, # vision_mask (empty for TKG) + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + generation_model = self.get_generation_model() + if not generation_model.is_neuron(): + self._copy_past_key_values(outputs) + + # Process outputs + constructed_outputs = self._get_constructed_outputs(outputs, is_run_on_neuron) + + # Apply tensor_capture_hook if provided and tensors are captured + if tensor_capture_hook and constructed_outputs.captured_tensors: + # Apply the hook if captured tensors are found + tensor_capture_hook(self, constructed_outputs.captured_tensors) + + return constructed_outputs + + +def patched_hf_adapter_prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + sampling_params=None, + adapter_ids=None, + **kwargs, +): + # Store KV cache flag before forward pass. 
+ self.prev_kv_cache_populated = self.neuron_model.kv_cache_populated + if self.neuron_model.kv_cache_populated: + input_ids = input_ids[:, -1:] + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + position_ids = kwargs.get("position_ids", None) + input_capture_hook = kwargs.get("input_capture_hook", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + if self.input_start_offsets: + if len(self.input_start_offsets) > 1: + position_ids += torch.tensor( + self.input_start_offsets, + dtype=position_ids.dtype, + device=position_ids.device, + )[:, None] + else: + position_ids += self.input_start_offsets[0] + for i, offset in enumerate(self.input_start_offsets): + position_ids[i, 0:offset] = torch.arange(offset) + else: + position_ids.masked_fill_(attention_mask == 0, 1) + + if self.neuron_model.kv_cache_populated: + position_ids = torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": ( + accepted_indices, + current_length, + medusa_mask, + scatter_index, + ), + "sampling_params": sampling_params, + "input_capture_hook": input_capture_hook, + # "tensor_capture_hook": tensor_capture_hook, -> FIX: Otherwise raises a breaking NameError + "adapter_ids": adapter_ids, + } + ) + + tf_args = [] + if 
self.neuron_config.tensor_replacement_config: + if hasattr(self, "generation_step"): + self.generation_step += 1 + else: + self.generation_step = 1 + reg = TensorReplacementRegister.get_instance() + tf, masks = reg.step_args(self.generation_step) + tf_args = tf + masks + + # Only add tf_args if not empty + if tf_args: + model_inputs["tf_args"] = tf_args + + # WARNING: This is needed for propagating additional kwargs to the neuron model + additional_kwargs = self.neuron_model.get_required_kwargs() + for arg in additional_kwargs: + model_inputs.update({arg: kwargs.get(arg, None)}) + + return model_inputs + + +# --------------------------------------------------------------------------- +# NKI Flash Attention kernel integration for head_dim > 128 +# --------------------------------------------------------------------------- +# Lazy-loaded kernel reference (compiled on first use) +_nki_flash_attn_kernel = None + + +def _get_nki_flash_attn_kernel(): + """Lazy-load and JIT-compile the NKI flash attention kernel.""" + global _nki_flash_attn_kernel + if _nki_flash_attn_kernel is not None: + return _nki_flash_attn_kernel + + try: + # Prefer relative import when this module ships inside `neuron_port`. + from .nki_flash_attn_large_d import flash_attn_large_d + except ImportError: + # Fallback for the original PR #106 layout. + from nki_flash_attn_large_d import flash_attn_large_d + + _nki_flash_attn_kernel = flash_attn_large_d + return _nki_flash_attn_kernel + + +def _nki_kernel_perform_prefill(self, Q, K, V, q_len, bsz, attention_mask): + """ + Replacement for perform_prefill that uses our custom NKI kernel + when head_dim > 128. Falls back to the original method otherwise. 
+ + Input: Q, K, V in BHSD layout [batch, num_heads, seq_len, head_dim] + Output: (attn_output in BHDS layout [batch, num_heads, head_dim, seq_len], + FlashAttentionStrategy.UNSHARDED_KERNEL) + """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + FlashAttentionStrategy, + ) + from neuronx_distributed_inference.modules.attention.attention_base import repeat_kv + + if self.head_dim <= 128: + return self._orig_perform_prefill(Q, K, V, q_len, bsz, attention_mask) + + # head_dim > 128: use our NKI kernel + kernel = _get_nki_flash_attn_kernel() + + # Q is BHSD: (bsz, num_heads, q_len, head_dim) + # Reshape to (bsz * num_heads, q_len, head_dim) for kernel (tp_q=True layout) + Q_3d = Q.reshape(bsz * self.num_heads, q_len, self.head_dim).to(self.torch_dtype) + + # GQA: K/V are NOT replicated here; they keep num_kv_heads and are only reshaped + num_kv_heads = self.num_key_value_heads + K_active = K # already (bsz, num_kv_heads, q_len, head_dim) + V_active = V + K_3d = K_active.reshape(bsz * num_kv_heads, q_len, self.head_dim).to( + self.torch_dtype + ) + V_3d = V_active.reshape(bsz * num_kv_heads, q_len, self.head_dim).to( + self.torch_dtype + ) + + # NxDI applies 1/sqrt(head_dim) to Q inside scaled_qk; on its kernel path + # the scaling is applied before the kernel call (line 788), so mirror that here + Q_3d = Q_3d / math.sqrt(self.head_dim) + + # Determine sliding window + sw = self.sliding_window if self.sliding_window else 0 + + # Grid: one program per KV group (kernel handles Q-head fan-out internally) + grid_bs = bsz * num_kv_heads + + # Call kernel + # kernel expects: q(bs, seq, d), k(bs_kv, seq, d), v(bs_kv, seq, d) + # returns: o(bs, d, seq) + attn_output = kernel[grid_bs]( + Q_3d, + K_3d, + V_3d, + scale=1.0, # scaling already applied to Q + use_causal_mask=(attention_mask is not None), + sliding_window=sw, + ) + + # Reshape output from (bsz * num_heads, head_dim, q_len) -> (bsz, num_heads, head_dim, q_len) = BHDS + attn_output = attn_output.reshape(bsz, self.num_heads, 
self.head_dim, q_len) + + return attn_output, FlashAttentionStrategy.UNSHARDED_KERNEL + + +def _nki_kernel_perform_prefill_windowed_attn( + self, Q, K, V, q_len, bsz, attention_mask, window_size +): + """ + Replacement for perform_prefill_windowed_attn that uses our custom NKI kernel + when head_dim > 128. Falls back to the original method otherwise. + + Input: Q, K, V in BHSD layout [batch, num_heads, seq_len, head_dim] + Output: (attn_output in BHDS layout, FlashAttentionStrategy.UNSHARDED_KERNEL) + """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + FlashAttentionStrategy, + ) + from neuronx_distributed_inference.modules.attention.attention_base import repeat_kv + + if self.head_dim <= 128: + return self._orig_perform_prefill_windowed_attn( + Q, K, V, q_len, bsz, attention_mask, window_size + ) + + # head_dim > 128: use our NKI kernel with sliding window + kernel = _get_nki_flash_attn_kernel() + + Q_3d = Q.reshape(bsz * self.num_heads, q_len, self.head_dim).to(self.torch_dtype) + + # For windowed attn, replicate K/V across the Q heads here (GQA) + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + K_3d = K_active.reshape(bsz * self.num_heads, q_len, self.head_dim).to( + self.torch_dtype + ) + V_3d = V_active.reshape(bsz * self.num_heads, q_len, self.head_dim).to( + self.torch_dtype + ) + + Q_3d = Q_3d / math.sqrt(self.head_dim) + + sw = window_size if window_size else 0 + + grid_bs = bsz * self.num_heads # After repeat_kv, all heads are present + + attn_output = kernel[grid_bs]( + Q_3d, + K_3d, + V_3d, + scale=1.0, + use_causal_mask=True, + sliding_window=sw, + ) + + attn_output = attn_output.reshape(bsz, self.num_heads, self.head_dim, q_len) + + return attn_output, FlashAttentionStrategy.UNSHARDED_KERNEL + + +def _patch_attention_modules_for_nki_kernel(): + """ + Monkey-patch NeuronAttentionBase.perform_prefill and + perform_prefill_windowed_attn to use our NKI kernel when 
head_dim > 128. + + This is called once at import time. The original methods are preserved + as _orig_perform_prefill and _orig_perform_prefill_windowed_attn. + """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + ) + + # Save originals + NeuronAttentionBase._orig_perform_prefill = NeuronAttentionBase.perform_prefill + NeuronAttentionBase._orig_perform_prefill_windowed_attn = ( + NeuronAttentionBase.perform_prefill_windowed_attn + ) + + # Replace with our wrappers + NeuronAttentionBase.perform_prefill = _nki_kernel_perform_prefill + NeuronAttentionBase.perform_prefill_windowed_attn = ( + _nki_kernel_perform_prefill_windowed_attn + ) + + logger.info("NKI flash attention kernel patch applied for head_dim > 128") + + +def apply_patch() -> None: + import neuronx_distributed_inference.modules.attention.utils as u + + u.get_last_kv_window = patched_get_last_kv_window + + import neuronx_distributed_inference.models.image_to_text_model_base as mm_base + + mm_base.NeuronBaseForImageToText.forward = patched_base_image_to_text_model_forward + + # Patch attention for NKI kernel with head_dim > 128 + _patch_attention_modules_for_nki_kernel() + + try: + import neuronx_distributed_inference.utils.hf_adapter as hf_adapter + + hf_adapter.HuggingFaceGenerationAdapter.prepare_inputs_for_generation = ( + patched_hf_adapter_prepare_inputs_for_generation + ) + except ImportError: + # hf_adapter may fail to import if transformers API changed + # (e.g., SampleDecoderOnlyOutput renamed). This patch is only + # needed for HF generate() integration, not core inference. 
+ pass diff --git a/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_d256_swa.py b/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_d256_swa.py new file mode 100644 index 00000000..2f838f5a --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_d256_swa.py @@ -0,0 +1,950 @@ +""" +Flash attention for d=256 with sliding window mask, optimized for Gemma4 E2B SWA layers. + +Adapted from Qwen3-Coder-Next nki_flash_attn_d256_pipe.py with: + - Sliding window mask (window_size=512) replacing full causal mask + - Tile-skip optimization: only K tiles within the window are processed + - No fused RoPE (E2B applies RoPE externally in prep_qkv_tensors) + - NKI 0.3.0+ API (affine_select keyword offset) + +Architecture: 2x128 QK tiling for head_dim=256, 3-stage software pipeline. +Called per (batch, kv_head) with pre-sliced post-RoPE Q/K/V in BHSD layout. + +E2B SWA layer specifics: + - head_dim = 256 + - num_kv_heads = 1 (GQA, 8 Q heads per KV head at full model, varies by TP) + - sliding_window = 512 + - Q/K already have RoPE applied (theta=10000, full rotation) + - V already has v_norm applied (V / RMS(V)) + +Layouts (per-call, single batch + single kv_head): + Q: (1, q_h_per_k_h, seq_q, 256) -- BHSD + K: (1, 1, seq_k, 256) -- BHSD + V: (1, 1, seq_v, 256) -- BHSD + O: (1, q_h_per_k_h, seq_q, 256) -- BHSD + +Internal SBUF layout after DMA transpose of Q/K: + Q_sb: (D_TILE=128, Q_GRP_SZ=128) -- d on partition, seq on free + K_sb: (D_TILE=128, K_TILE_SZ=512) -- d on partition, seq on free + V_sb: (V_TILE_SZ=128, D_HEAD=256) -- seq on partition, d on free +""" + +import os + +os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", "trn2") + +import math +import nki.isa as nisa +import nki.language as nl +import nki + +# ============================================================================ +# Constants +# ============================================================================ +D_HEAD = 256 +D_TILE = 128 # partition dim tile for d-tiling (256 = 
2 x 128) +Q_GRP_SZ = 128 # Q group size = partition dim max +K_TILE_SZ = 512 # K tile size for MM1 (free dim of K in matmul) +V_TILE_SZ = 128 # V tile size for MM2 (partition dim of transposed P) +LARGE_TILE_SZ = 2048 # Large tile grouping +EXP_TILE_SZ = 512 # Exp tile for activation_reduce +PSUM_FMAX = 512 # PSUM free dimension max +FLOAT32_MIN = -3.4028235e38 + + +# ============================================================================ +# ModularAllocator helpers +# ============================================================================ + + +def _align32(addr): + """Round up address to 32-byte alignment (required for DMA transpose).""" + return (addr + 31) // 32 * 32 + + +def _alloc_modular_1d(shape, dtype, block_dim, num_free_tiles, base_addr): + """Allocate 1D modular buffer list.""" + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(block_dim): + tensors.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + next_addr = base_addr + num_free_tiles * tile_bytes + return tensors, next_addr + + +def _alloc_modular_2d( + shape, dtype, block_dim0, block_dim1, num_free0, num_free1, base_addr +): + """Allocate 2D modular buffer.""" + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(block_dim0): + row = [] + for j in range(block_dim1): + row.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + tensors.append(row) + next_addr = base_addr + num_free0 * num_free1 * tile_bytes + return tensors, next_addr + + +def _alloc_modular_3d(shape, dtype, dims, n_free, base_addr): + """Allocate 3D modular buffer.""" + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == 
nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(dims[0]): + layer = [] + for j in range(dims[1]): + row = [] + for k in range(dims[2]): + row.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + layer.append(row) + tensors.append(layer) + total_physical = n_free[0] * n_free[1] * n_free[2] + next_addr = base_addr + total_physical * tile_bytes + return tensors, next_addr + + +# ============================================================================ +# Pipeline stage functions +# ============================================================================ + + +def _pipe_load_q( + grp_i, + q_sb_lo, + q_sb_hi, + q_hbm, + d_tile, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d_head, +): + """Load Q group from BHSD HBM into SBUF with DMA transpose to (D, S) layout.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + q_offset = ( + batch_id * n_heads * seqlen_q * d_head + + q_head_idx * seqlen_q * d_head + + q_start * d_head + ) + + # Lo half: D[0:128] + nisa.dma_transpose( + dst=q_sb_lo[grp_i].ap([[Q_GRP_SZ, d_tile], [1, 1], [1, 1], [1, num_q]]), + src=q_hbm.ap( + [[d_head, num_q], [1, 1], [1, 1], [1, d_tile]], + offset=q_offset, + ), + ) + # Hi half: D[128:256] + nisa.dma_transpose( + dst=q_sb_hi[grp_i].ap([[Q_GRP_SZ, d_tile], [1, 1], [1, 1], [1, num_q]]), + src=q_hbm.ap( + [[d_head, num_q], [1, 1], [1, 1], [1, d_tile]], + offset=q_offset + d_tile, + ), + ) + + +def _pipe_qk_and_max( + grp_i, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + window_size, +): + """Compute QK^T with d=256 tiling, sliding window mask, scale, and row-wise max. 
+ + Sliding window mask: keep when q_pos - window_size < k_pos <= q_pos + i.e., k_start + f <= q_start + p (upper bound, same as causal) + AND k_start + f > q_start + p - window_size (lower bound) + + Combined as affine_select with causal upper bound, plus tile-level skip + for K tiles entirely outside the window. + """ + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_k_tiles_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + + # Initialize partial max to -inf + nisa.memset(mm1_partial_max[grp_i][...], value=FLOAT32_MIN) + + # Initialize mm1_masked to -inf (skipped K tiles → exp=0) + for lt_idx in range(num_large_tiles): + nisa.memset(mm1_masked[grp_i][lt_idx][...], value=FLOAT32_MIN) + + for large_tile_idx in range(num_large_tiles): + for k_tile_local in range(num_k_tiles_per_large): + k_tile_idx = large_tile_idx * num_k_tiles_per_large + k_tile_local + if k_tile_idx >= num_k_tiles: + continue + + k_start = k_tile_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + if num_k <= 0: + continue + + # Tile-level skip: sliding window bounds + # Upper bound: q_last < k_start means all Q positions before this K tile + q_last = q_start + num_q - 1 + if q_last < k_start: + continue + + # Lower bound: k_end <= q_first - window means K tile entirely before window + q_first = q_start + k_end = k_start + num_k - 1 + if k_end < q_first - window_size + 1: + continue + + # MM1: QK = Q_lo^T @ K_lo + Q_hi^T @ K_hi + psum_tile = mm1_psum[grp_i][large_tile_idx][k_tile_local] + + # First half: d[0:128] + nisa.nc_matmul( + psum_tile[:num_q, :num_k], + q_sb_lo[grp_i][:D_TILE, :num_q], + k_sb_lo[k_tile_idx][:D_TILE, :num_k], + ) + # Second half: d[128:256] — accumulates into same PSUM + nisa.nc_matmul( + psum_tile[:num_q, :num_k], + q_sb_hi[grp_i][:D_TILE, :num_q], + k_sb_hi[k_tile_idx][:D_TILE, :num_k], + ) + + # Copy PSUM -> temp SBUF (unscaled) + nisa.tensor_copy( + mm1_copy_sb[:num_q, :num_k], + psum_tile[:num_q, :num_k], + ) + + # Sliding window 
mask via affine_select (NKI 0.3.0 API) + # Upper bound (causal): keep when (k_start+f) <= (q_start+p) + # Pattern: (1)*p + (-1)*f + offset >= 0 => f <= p + offset + # offset = q_start - k_start + # This masks future tokens (same as causal) + nisa.affine_select( + dst=mm1_asel_sb[:num_q, :num_k], + pattern=[[-1, num_k]], + channel_multiplier=1, + on_true_tile=mm1_copy_sb[:num_q, :num_k], + on_false_value=FLOAT32_MIN, + offset=q_start - k_start, + cmp_op=nl.greater_equal, + ) + + # Lower bound: keep when k_pos >= q_pos - window_size + 1 + # where k_pos = k_start + f, q_pos = q_start + p + # => (k_start + f) >= (q_start + p) - window_size + 1 + # => f >= p + (q_start - k_start) - window_size + 1 + # => f >= p - (k_start - q_start + window_size - 1) + # + # affine_select with ch_mul=-1, pattern=[[1, num_k]]: + # val = (-1)*p + (1)*f + offset >= 0 + # => f >= p - offset + # So offset = k_start - q_start + window_size - 1 + lower_offset = k_start - q_start + window_size - 1 + nisa.affine_select( + dst=mm1_asel_sb[:num_q, :num_k], + pattern=[[1, num_k]], + channel_multiplier=-1, + on_true_tile=mm1_asel_sb[:num_q, :num_k], + on_false_value=FLOAT32_MIN, + offset=lower_offset, + cmp_op=nl.greater_equal, + ) + + # Scale + max extraction + nisa.tensor_scalar_reduce( + mm1_masked[grp_i][large_tile_idx][ + :num_q, nl.ds(k_tile_local * K_TILE_SZ, num_k) + ], + data=mm1_asel_sb[:num_q, :num_k], + op0=nl.multiply, + operand0=scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max[grp_i][:num_q, k_tile_idx], + ) + + +def _pipe_update_max( + grp_i, mm1_partial_max, mm1_section_max, mm1_running_max, num_k_tiles, seqlen_q +): + """Compute section max from partial maxes, store as -max (negated).""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + nisa.tensor_reduce( + mm1_section_max[grp_i][:num_q, 0], + nl.maximum, + mm1_partial_max[grp_i][:num_q, :num_k_tiles], + 1, + negate=True, + ) + + nisa.tensor_copy(mm1_running_max[:num_q, grp_i], 
mm1_section_max[grp_i][:num_q, 0]) + + +def _pipe_exp( + grp_i, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, +): + """Compute exp(S - max), partial sums, and DMA transpose for MM2.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_exp_per_large = LARGE_TILE_SZ // EXP_TILE_SZ # 4 + + nisa.memset(exp_partial_sum[grp_i][...], value=0.0) + + for large_tile_idx in range(num_large_tiles): + for exp_tile_idx in range(num_exp_per_large): + kv_start = large_tile_idx * LARGE_TILE_SZ + exp_tile_idx * EXP_TILE_SZ + num_kv = min(seqlen_kv - kv_start, EXP_TILE_SZ) + if num_kv <= 0: + continue + + nisa.activation_reduce( + exp_sb[grp_i][large_tile_idx][ + :num_q, nl.ds(exp_tile_idx * EXP_TILE_SZ, num_kv) + ], + op=nl.exp, + data=mm1_masked[grp_i][large_tile_idx][ + :num_q, nl.ds(exp_tile_idx * EXP_TILE_SZ, num_kv) + ], + reduce_op=nl.add, + reduce_res=exp_partial_sum[grp_i][ + :num_q, + large_tile_idx * num_exp_per_large + exp_tile_idx, + ], + bias=mm1_running_max[:num_q, grp_i], + ) + + # DMA transpose: exp_sb[Q=128, KV=512] -> exp_tp_sb[KV=128, Q=512] + num_kv_outer = num_kv // V_TILE_SZ + num_kv_inner = num_kv % V_TILE_SZ + + if num_kv_outer >= 1: + nisa.dma_transpose( + dst=exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [K_TILE_SZ, V_TILE_SZ], + [1, 1], + [V_TILE_SZ, num_kv_outer], + [1, num_q], + ] + ), + src=exp_sb[grp_i][large_tile_idx].ap( + [ + [LARGE_TILE_SZ, num_q], + [1, 1], + [V_TILE_SZ, num_kv_outer], + [1, V_TILE_SZ], + ], + offset=exp_tile_idx * K_TILE_SZ, + ), + ) + + if num_kv_inner > 0: + nisa.dma_transpose( + dst=exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [K_TILE_SZ, num_kv_inner], + [1, 1], + [V_TILE_SZ, 1], + [1, num_q], + ], + offset=num_kv_outer * V_TILE_SZ, + ), + src=exp_sb[grp_i][large_tile_idx].ap( + [ + [LARGE_TILE_SZ, num_q], + [1, 1], + [V_TILE_SZ, 1], + [1, num_kv_inner], + ], + offset=exp_tile_idx * 
K_TILE_SZ + num_kv_outer * V_TILE_SZ, + ), + ) + + +def _pipe_pv( + grp_i, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_v_tiles, +): + """Compute P@V (MM2) with d=256 split into lo/hi halves.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_mm2_grps_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + num_mm2_per_grp = K_TILE_SZ // V_TILE_SZ # 4 + + nisa.memset(mm2_sb[grp_i][...], value=0.0) + + for large_tile_idx in range(num_large_tiles): + psum_tile_lo = mm2_psum_lo[grp_i][large_tile_idx] + psum_tile_hi = mm2_psum_hi[grp_i][large_tile_idx] + + for mm2_grp_i in range(num_mm2_grps_per_large): + exp_tp_tile = exp_tp_sb[grp_i][large_tile_idx][mm2_grp_i] + + for mm2_i in range(num_mm2_per_grp): + v_tile_idx = ( + large_tile_idx * num_mm2_grps_per_large * num_mm2_per_grp + + mm2_grp_i * num_mm2_per_grp + + mm2_i + ) + kv_start = v_tile_idx * V_TILE_SZ + num_kv = min(seqlen_kv - kv_start, V_TILE_SZ) + if num_kv <= 0 or v_tile_idx >= num_v_tiles: + continue + + # MM2 lo: exp_tp^T @ V[:, :128] + nisa.nc_matmul( + psum_tile_lo[:num_q, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q)], + v_sb[v_tile_idx][:num_kv, :D_TILE], + ) + # MM2 hi: exp_tp^T @ V[:, 128:256] + nisa.nc_matmul( + psum_tile_hi[:num_q, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q)], + v_sb[v_tile_idx][:num_kv, nl.ds(D_TILE, D_TILE)], + ) + + # Accumulate large tile results into SBUF + if large_tile_idx == 0: + nisa.tensor_copy( + mm2_sb[grp_i][:num_q, :D_TILE], psum_tile_lo[:num_q, :D_TILE] + ) + nisa.tensor_copy( + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + psum_tile_hi[:num_q, :D_TILE], + ) + else: + nisa.tensor_tensor( + mm2_sb[grp_i][:num_q, :D_TILE], + mm2_sb[grp_i][:num_q, :D_TILE], + psum_tile_lo[:num_q, :D_TILE], + nl.add, + ) + nisa.tensor_tensor( + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + psum_tile_hi[:num_q, 
:D_TILE], + nl.add, + ) + + +def _pipe_write_back( + grp_i, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o_hbm, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, +): + """Write-back: normalize by 1/sum(exp), cast to bf16, DMA to HBM.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + nisa.tensor_reduce( + wb_exp_section_sum[grp_i][:num_q, 0], + nl.add, + exp_partial_sum[grp_i][:num_q, :num_exp_tiles], + axis=1, + ) + + nisa.reciprocal( + exp_sum_recip[grp_i][:num_q, 0], + wb_exp_section_sum[grp_i][:num_q, 0], + ) + + # Scale output and cast to bf16 + nisa.activation( + wb_o_bf16[grp_i][:num_q, :D_HEAD], + nl.copy, + mm2_sb[grp_i][:num_q, :D_HEAD], + scale=exp_sum_recip[grp_i][:num_q, 0], + bias=wb_zero_bias[:num_q], + ) + + # DMA to HBM output + nisa.dma_copy( + dst=o_hbm[batch_id, q_head_idx, q_start : q_start + num_q, 0:D_HEAD], + src=wb_o_bf16[grp_i][:num_q, :D_HEAD], + ) + + +# ============================================================================ +# Main kernel +# ============================================================================ + + +@nki.jit +def flash_attn_d256_swa( + q, + k, + v, + q_h_per_k_h=8, + n_kv_heads=1, + seqlen_q=512, + seqlen_kv=512, + window_size=512, +): + """ + Flash attention for head_dim=256 with sliding window mask. + + Called per (batch, kv_head) pair with pre-sliced post-RoPE tensors. 
+ + Args: + q: (1, q_h_per_k_h, seq_q, 256) -- bfloat16, BHSD, post-RoPE + k: (1, 1, seq_k, 256) -- bfloat16, BHSD, post-RoPE + v: (1, 1, seq_v, 256) -- bfloat16, BHSD, post-v_norm + q_h_per_k_h: Q heads per KV head (8 for E2B full model) + n_kv_heads: must be 1 (kernel processes one KV head at a time) + seqlen_q: sequence length for Q + seqlen_kv: sequence length for K/V + window_size: sliding window size (512 for E2B SWA layers) + + Returns: + o: (1, q_h_per_k_h, seq_q, 256) -- bfloat16, BHSD + """ + d = D_HEAD + n_heads = q_h_per_k_h * n_kv_heads + scale = 1.0 / math.sqrt(d) + + batch_id = 0 + kv_head_id = 0 + + # Output allocation + o = nl.ndarray((1, n_heads, seqlen_q, d), dtype=nl.bfloat16, buffer=nl.shared_hbm) + + num_grps = (seqlen_q + Q_GRP_SZ - 1) // Q_GRP_SZ + num_k_tiles = (seqlen_kv + K_TILE_SZ - 1) // K_TILE_SZ + num_v_tiles = (seqlen_kv + V_TILE_SZ - 1) // V_TILE_SZ + num_large_tiles = (seqlen_kv + LARGE_TILE_SZ - 1) // LARGE_TILE_SZ + num_exp_per_large = LARGE_TILE_SZ // EXP_TILE_SZ + num_exp_tiles = num_large_tiles * num_exp_per_large + + # ========================================================================= + # Buffer Allocation + # ========================================================================= + sca = 0 + + k_sb_lo, sca = _alloc_modular_1d( + (D_TILE, K_TILE_SZ), nl.bfloat16, num_k_tiles, num_k_tiles, sca + ) + k_sb_hi, sca = _alloc_modular_1d( + (D_TILE, K_TILE_SZ), nl.bfloat16, num_k_tiles, num_k_tiles, sca + ) + v_sb, sca = _alloc_modular_1d( + (V_TILE_SZ, D_HEAD), nl.bfloat16, num_v_tiles, num_v_tiles, sca + ) + q_sb_lo, sca = _alloc_modular_1d((D_TILE, Q_GRP_SZ), nl.bfloat16, num_grps, 2, sca) + q_sb_hi, sca = _alloc_modular_1d((D_TILE, Q_GRP_SZ), nl.bfloat16, num_grps, 2, sca) + + # Masking temp buffers + sca = _align32(sca) + mm1_copy_sb = nl.ndarray((Q_GRP_SZ, K_TILE_SZ), dtype=nl.float32, buffer=nl.sbuf) + sca += K_TILE_SZ * 4 + sca = _align32(sca) + mm1_asel_sb = nl.ndarray((Q_GRP_SZ, K_TILE_SZ), dtype=nl.float32, 
buffer=nl.sbuf) + sca += K_TILE_SZ * 4 + + mm1_masked, sca = _alloc_modular_2d( + (Q_GRP_SZ, LARGE_TILE_SZ), + nl.float32, + num_grps, + num_large_tiles, + 2, + num_large_tiles, + sca, + ) + mm1_partial_max, sca = _alloc_modular_1d( + (Q_GRP_SZ, num_k_tiles), nl.float32, num_grps, 2, sca + ) + mm1_section_max, sca = _alloc_modular_1d( + (Q_GRP_SZ, 1), nl.float32, num_grps, 2, sca + ) + + sca = _align32(sca) + mm1_running_max = nl.ndarray((Q_GRP_SZ, num_grps), dtype=nl.float32, buffer=nl.sbuf) + sca += num_grps * 4 + + exp_sb, sca = _alloc_modular_2d( + (Q_GRP_SZ, LARGE_TILE_SZ), + nl.bfloat16, + num_grps, + num_large_tiles, + 1, + num_large_tiles, + sca, + ) + exp_partial_sum, sca = _alloc_modular_1d( + (Q_GRP_SZ, num_exp_tiles), nl.float32, num_grps, 2, sca + ) + exp_tp_sb, sca = _alloc_modular_3d( + (V_TILE_SZ, K_TILE_SZ), + nl.bfloat16, + (num_grps, num_large_tiles, num_exp_per_large), + (2, num_large_tiles, num_exp_per_large), + sca, + ) + mm2_sb, sca = _alloc_modular_1d((Q_GRP_SZ, D_HEAD), nl.float32, num_grps, 2, sca) + exp_sum_recip, sca = _alloc_modular_1d((Q_GRP_SZ, 1), nl.float32, num_grps, 2, sca) + + wb_exp_section_sum, sca = _alloc_modular_1d( + (Q_GRP_SZ, 1), nl.float32, num_grps, 2, sca + ) + sca = _align32(sca) + wb_zero_bias = nl.ndarray((Q_GRP_SZ, 1), dtype=nl.float32, buffer=nl.sbuf) + sca += 1 * 4 + wb_o_bf16, sca = _alloc_modular_1d( + (Q_GRP_SZ, D_HEAD), nl.bfloat16, num_grps, 2, sca + ) + + # ========================================================================= + # GQA outer loop + # ========================================================================= + for i_q_h in range(q_h_per_k_h): + q_head_idx = kv_head_id * q_h_per_k_h + i_q_h + + # PSUM allocations (per GQA iteration) + mm1_psum = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + tile_row = [] + for kt_idx in range(4): + tile_row.append( + nl.ndarray( + (Q_GRP_SZ, PSUM_FMAX), dtype=nl.float32, buffer=nl.psum + ) + ) + 
grp_row.append(tile_row) + mm1_psum.append(grp_row) + + mm2_psum_lo = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + grp_row.append( + nl.ndarray((Q_GRP_SZ, D_TILE), dtype=nl.float32, buffer=nl.psum) + ) + mm2_psum_lo.append(grp_row) + + mm2_psum_hi = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + grp_row.append( + nl.ndarray((Q_GRP_SZ, D_TILE), dtype=nl.float32, buffer=nl.psum) + ) + mm2_psum_hi.append(grp_row) + + # Load K and V (shared across Q heads in GQA) + for k_idx in nl.affine_range(num_k_tiles): + k_start = k_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + k_offset = ( + batch_id * n_kv_heads * seqlen_kv * d + + kv_head_id * seqlen_kv * d + + k_start * d + ) + nisa.dma_transpose( + dst=k_sb_lo[k_idx].ap( + [[K_TILE_SZ, D_TILE], [1, 1], [1, 1], [1, num_k]] + ), + src=k.ap([[d, num_k], [1, 1], [1, 1], [1, D_TILE]], offset=k_offset), + ) + nisa.dma_transpose( + dst=k_sb_hi[k_idx].ap( + [[K_TILE_SZ, D_TILE], [1, 1], [1, 1], [1, num_k]] + ), + src=k.ap( + [[d, num_k], [1, 1], [1, 1], [1, D_TILE]], offset=k_offset + D_TILE + ), + ) + + for v_idx in nl.affine_range(num_v_tiles): + v_start = v_idx * V_TILE_SZ + num_v = min(seqlen_kv - v_start, V_TILE_SZ) + nisa.dma_copy( + dst=v_sb[v_idx][:num_v, :D_HEAD], + src=v[batch_id, kv_head_id, v_start : v_start + num_v, 0:D_HEAD], + ) + + nisa.memset(wb_zero_bias, value=0.0) + + # ===================================================================== + # Sequential execution (no pipelining for initial version — simpler, easier to debug) + # Full pipeline can be added once correctness is validated. 
+ # ===================================================================== + for grp_i in range(num_grps): + _pipe_load_q( + grp_i, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + ) + + _pipe_qk_and_max( + grp_i, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + window_size, + ) + + _pipe_update_max( + grp_i, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + + _pipe_exp( + grp_i, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + + _pipe_pv( + grp_i, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_v_tiles, + ) + + _pipe_write_back( + grp_i, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, + ) + + return o + + +# ============================================================================ +# Unit test +# ============================================================================ +if __name__ == "__main__": + import torch + import torch.nn.functional as F + import time + + def reference_swa_attention(q, k, v, window_size): + """CPU reference: sliding window attention. 
+ q(b,h,sq,d), k(b,h,sk,d), v(b,h,sk,d) -> (b,h,sq,d) + """ + d = q.shape[3] + q_t = q.float() + k_t = k.float() + v_t = v.float() + scale = 1.0 / (d**0.5) + attn = q_t @ k_t.transpose(-2, -1) * scale + + sq, sk = q_t.shape[2], k_t.shape[2] + # Sliding window mask: position i attends to [i - window_size + 1, i] + row_idx = torch.arange(sq).unsqueeze(1) # (sq, 1) + col_idx = torch.arange(sk).unsqueeze(0) # (1, sk) + # Mask out: future (col > row) OR too far past (col < row - window + 1) + mask = (col_idx > row_idx) | (col_idx < row_idx - window_size + 1) + attn = attn.masked_fill(mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + return attn @ v_t + + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + print("=" * 70) + print("Flash Attention d=256 with Sliding Window Mask (E2B SWA)") + print("=" * 70) + + tests = [ + { + "seq": 512, + "heads": 1, + "kv_heads": 1, + "window": 512, + "label": "seq=512 w=512 1:1", + }, + { + "seq": 1024, + "heads": 1, + "kv_heads": 1, + "window": 512, + "label": "seq=1024 w=512 1:1", + }, + { + "seq": 1024, + "heads": 8, + "kv_heads": 1, + "window": 512, + "label": "seq=1024 w=512 GQA 8:1", + }, + { + "seq": 512, + "heads": 8, + "kv_heads": 1, + "window": 512, + "label": "seq=512 w=512 GQA 8:1", + }, + ] + + for t in tests: + seq_len = t["seq"] + heads = t["heads"] + kv_heads = t["kv_heads"] + window = t["window"] + d = 256 + print(f"\n=== Testing: {t['label']} ===") + torch.manual_seed(42) + q = torch.randn(1, heads, seq_len, d, dtype=torch.bfloat16) + k = torch.randn(1, kv_heads, seq_len, d, dtype=torch.bfloat16) + v = torch.randn(1, kv_heads, seq_len, d, dtype=torch.bfloat16) + + # CPU reference with GQA expansion + ref_parts = [] + for h_idx in range(heads): + kv_idx = h_idx // (heads // kv_heads) + ref_h = reference_swa_attention( + q[:, h_idx : h_idx + 1], + k[:, kv_idx : kv_idx + 1], + v[:, kv_idx : kv_idx + 1], + window, + ) + ref_parts.append(ref_h) + ref = torch.cat(ref_parts, dim=1) + + # Run kernel 
+ q_dev = q.to(device) + k_dev = k.to(device) + v_dev = v.to(device) + + t0 = time.time() + q_h_per_kv = heads // kv_heads + out_parts = [] + for kv_h in range(kv_heads): + q_slice = q_dev[:, kv_h * q_h_per_kv : (kv_h + 1) * q_h_per_kv, :, :] + k_slice = k_dev[:, kv_h : kv_h + 1, :, :] + v_slice = v_dev[:, kv_h : kv_h + 1, :, :] + o_part = flash_attn_d256_swa( + q_slice, + k_slice, + v_slice, + q_h_per_k_h=q_h_per_kv, + n_kv_heads=1, + seqlen_q=seq_len, + seqlen_kv=seq_len, + window_size=window, + ) + out_parts.append(o_part) + out = torch.cat(out_parts, dim=1) + xm.mark_step() + out_cpu = out.cpu().float() + t1 = time.time() + + cos_sim = F.cosine_similarity( + ref.reshape(-1).unsqueeze(0), out_cpu.reshape(-1).unsqueeze(0) + ).item() + maxd = (ref - out_cpu).abs().max().item() + print(f" Time: {t1 - t0:.1f}s (includes compile)") + print(f" Cosine sim: {cos_sim:.6f}") + print(f" Max diff: {maxd:.6f}") + print(f" {'PASS' if cos_sim > 0.999 else 'FAIL'}") diff --git a/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_large_d.py b/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_large_d.py new file mode 100644 index 00000000..247e16f8 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/src/nki_flash_attn_large_d.py @@ -0,0 +1,346 @@ +""" +NKI Flash Attention kernel supporting head_dim > 128 (up to 512). + +Designed for NxDI integration: accepts the same I/O layout as NxDI's +standard flash attention kernel path (tp_q=True, tp_k=True). + +Input layout: + Q: (B*H, seqlen, d) -- tp_q=True layout + K: (B*H_kv, seqlen, d) -- tp_k=True layout + V: (B*H_kv, seqlen, d) + +Output layout: + O: (B*H, d, seqlen) -- tp_out=True (what NxDI expects from kernel) + +The kernel tiles the QK matmul contraction dimension in chunks of 128 +(the hardware partition axis max). For d=256: 2 chunks, for d=512: 4 chunks. +The PV matmul places d on the free axis (max 512), supporting up to d=512. 
+ +Supports: + - Causal masking + - GQA (B*H > B*H_kv, with B*H divisible by B*H_kv) + - Sliding window attention (optional) + - head_dim = 128, 256, or 512 + +Based on the proven nki_flash_attn_d256.py kernel (cosine > 0.999 validated). +""" + +import math +import numpy as np +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.language as nl +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + +B_P = 128 # partition dim max (nl.tile_size.pmax) +B_F = 512 # free dim max for matmul moving operand (nl.tile_size.gemm_moving_fmax) +D_TILE = 128 # head_dim tile size for QK contraction +NEG_INF = -9984.0 # bfloat16-safe negative infinity + + +@nki.jit +def flash_attn_large_d( + q, + k, + v, + scale: float = 1.0, + use_causal_mask: bool = True, + sliding_window: int = 0, +): + """ + Flash attention for head_dim up to 512. + + Args: + q: (bs, seqlen_q, d) -- bfloat16, tp_q=True layout (B*H merged into batch) + k: (bs_kv, seqlen_k, d) -- bfloat16, tp_k=True layout (B*H_kv merged) + v: (bs_kv, seqlen_k, d) -- bfloat16 + scale: float, scaling factor (already applied to Q by NxDI: Q = Q / sqrt(d)) + use_causal_mask: bool + sliding_window: int, 0 means no sliding window + + Returns: + o: (bs, d, seqlen_q) -- bfloat16, tp_out=True (BHDS after unmerge) + """ + bs, seqlen_q, d = q.shape + bs_kv, seqlen_k, _ = k.shape + + assert d <= 512, f"head_dim must be <= 512, got {d}" + assert d % D_TILE == 0, f"head_dim must be divisible by {D_TILE}, got {d}" + assert seqlen_q % B_P == 0, f"seqlen_q must be divisible by {B_P}, got {seqlen_q}" + assert seqlen_k % B_F == 0 or seqlen_k % B_P == 0, ( + f"seqlen_k must be divisible by {B_F} or {B_P}, got {seqlen_k}" + ) + + num_d_chunks = d // D_TILE # 1 for d=128, 2 for d=256, 4 for d=512 + q_h_per_k_h = bs // bs_kv # GQA ratio + + # Output: (bs, d, seqlen_q) -- transposed layout for NxDI + o = nl.ndarray((bs, d, seqlen_q), dtype=q.dtype, buffer=nl.shared_hbm) + + batch_id = nl.program_id(axis=0) + + n_q_tiles = seqlen_q 
// B_P + # K/V tiles: use B_F if possible, else B_P + kv_tile_size = B_F if seqlen_k % B_F == 0 else B_P + n_kv_tiles = seqlen_k // kv_tile_size + + for i_q_h in nl.affine_range(q_h_per_k_h): + q_batch = batch_id * q_h_per_k_h + i_q_h + k_batch = batch_id + + for qi in nl.sequential_range(n_q_tiles): + # Accumulators + o_acc = nl.zeros((par_dim(B_P), d), dtype=np.float32, buffer=nl.sbuf) + m_acc = nl.full((par_dim(B_P), 1), fill_value=NEG_INF, dtype=np.float32) + l_acc = nl.full((par_dim(B_P), 1), fill_value=NEG_INF, dtype=np.float32) + + # Load Q tile: num_d_chunks chunks of (D_TILE, B_P) + # Q is (bs, seqlen_q, d), we need it transposed to (d_chunk, seqlen_chunk) + # for the QK matmul: Q^T @ K where Q is (D_TILE, B_P) contraction on D_TILE + q_chunks = nl.ndarray( + (num_d_chunks, par_dim(D_TILE), B_P), dtype=nl.bfloat16 + ) + for dc in nl.affine_range(num_d_chunks): + # Load from (bs, seqlen, d) -> need to transpose (seqlen_tile, d_chunk) + # to (d_chunk, seqlen_tile) for matmul contraction + q_tile_raw = nl.ndarray((par_dim(B_P), D_TILE), dtype=nl.bfloat16) + q_tile_raw[:, :] = nl.load( + q[q_batch, nl.ds(qi * B_P, B_P), nl.ds(dc * D_TILE, D_TILE)] + ) + # Transpose: (B_P, D_TILE) -> (D_TILE, B_P) via nc_transpose + # But nc_transpose output goes to PSUM, need to copy to SBUF + q_t_psum = nl.ndarray( + (par_dim(D_TILE), B_P), dtype=np.float32, buffer=nl.psum + ) + q_t_psum[:, :] = nisa.nc_transpose(q_tile_raw) + q_chunks[dc, :, :] = nl.copy(q_t_psum, dtype=nl.bfloat16) + + # Scale Q only when scale != 1.0 (NxDI normally pre-scales Q and passes 1.0) + if scale != 1.0: + for dc in nl.affine_range(num_d_chunks): + q_chunks[dc, :, :] = nl.multiply(q_chunks[dc], scale) + + for kvi in nl.sequential_range(n_kv_tiles): + kv_start = kvi * kv_tile_size + num_kv = kv_tile_size # actual KV tokens in this tile + + # Causal mask: skip if Q tile is entirely before K tile (no Q can attend to K) + if use_causal_mask: + q_end = (qi + 1) * B_P - 1 + skip_condition = q_end < kv_start + else: 
+ skip_condition = False + + # Sliding window: skip if K tile is entirely before the window + # of the FIRST Q position in this tile + if sliding_window > 0 and use_causal_mask: + q_start = qi * B_P + kv_end = kv_start + kv_tile_size - 1 + # Window for q_start covers [q_start - sw + 1, q_start] + # Skip if kv_end < q_start - sw + 1 + skip_sw = kv_end < (q_start - sliding_window + 1) + skip_condition = skip_condition or skip_sw + + if not skip_condition: + # Load K tile: (num_d_chunks, par_dim(D_TILE), kv_tile_size) + # K is (bs_kv, seqlen_k, d), need transposed to (d_chunk, seqlen_chunk) + k_chunks = nl.ndarray( + (num_d_chunks, par_dim(D_TILE), kv_tile_size), + dtype=nl.bfloat16, + ) + for dc in nl.affine_range(num_d_chunks): + if kv_tile_size <= B_P: + # Small tile: load (B_P, D_TILE) and transpose + k_raw = nl.ndarray( + (par_dim(kv_tile_size), D_TILE), dtype=nl.bfloat16 + ) + k_raw[:, :] = nl.load( + k[ + k_batch, + nl.ds(kv_start, kv_tile_size), + nl.ds(dc * D_TILE, D_TILE), + ] + ) + k_t_psum = nl.ndarray( + (par_dim(D_TILE), kv_tile_size), + dtype=np.float32, + buffer=nl.psum, + ) + k_t_psum[:, :] = nisa.nc_transpose(k_raw) + k_chunks[dc, :, :] = nl.copy(k_t_psum, dtype=nl.bfloat16) + else: + # Large tile (B_F=512): load in sub-tiles of B_P and transpose each + n_sub = kv_tile_size // B_P + for si in nl.affine_range(n_sub): + k_raw = nl.ndarray( + (par_dim(B_P), D_TILE), dtype=nl.bfloat16 + ) + k_raw[:, :] = nl.load( + k[ + k_batch, + nl.ds(kv_start + si * B_P, B_P), + nl.ds(dc * D_TILE, D_TILE), + ] + ) + k_t_psum = nl.ndarray( + (par_dim(D_TILE), B_P), + dtype=np.float32, + buffer=nl.psum, + ) + k_t_psum[:, :] = nisa.nc_transpose(k_raw) + k_chunks[dc, :, nl.ds(si * B_P, B_P)] = nl.copy( + k_t_psum, dtype=nl.bfloat16 + ) + + # Tiled QK matmul: accumulate over d-chunks + # Each chunk: (D_TILE, B_P)^T @ (D_TILE, kv_tile_size) + # = (B_P, kv_tile_size) + qk = nl.ndarray( + (par_dim(B_P), kv_tile_size), + dtype=np.float32, + buffer=nl.psum, + ) + qk[:, :] = 
nl.matmul(q_chunks[0], k_chunks[0], transpose_x=True) + for dc in nl.affine_range(num_d_chunks - 1): + qk[:, :] += nl.matmul( + q_chunks[dc + 1], k_chunks[dc + 1], transpose_x=True + ) + + # Move to SBUF for masking/softmax + qk_sbuf = nl.ndarray( + (par_dim(B_P), kv_tile_size), dtype=np.float32, buffer=nl.sbuf + ) + + # Apply causal mask (and sliding window if enabled) + if use_causal_mask: + i_q, i_k = nl.mgrid[0:B_P, 0:kv_tile_size] + q_pos = qi * B_P + i_q + k_pos = kv_start + i_k + pred_causal = q_pos >= k_pos + + qk_sbuf[:, :] = nisa.affine_select( + pred=pred_causal, + on_true_tile=qk, + on_false_value=NEG_INF, + dtype=np.float32, + ) + + if sliding_window > 0: + # Apply sliding window mask on top of causal mask + pred_sw = (q_pos - k_pos) < sliding_window + qk_sw = nl.ndarray( + (par_dim(B_P), kv_tile_size), + dtype=np.float32, + buffer=nl.sbuf, + ) + qk_sw[:, :] = nisa.affine_select( + pred=pred_sw, + on_true_tile=qk_sbuf, + on_false_value=NEG_INF, + dtype=np.float32, + ) + qk_sbuf = qk_sw + else: + qk_sbuf[:, :] = nl.copy(qk, dtype=np.float32) + + # Row max for online softmax + new_max = nisa.tensor_reduce( + np.max, qk_sbuf, axis=(1,), dtype=np.float32, negate=False + ) + + m_prev = nl.copy(m_acc[:, 0]) + m_acc[:, 0] = nl.maximum(m_prev, new_max) + m_cur = m_acc[:, 0] + + # Rescale previous output + alpha = nisa.activation(np.exp, m_cur, bias=m_prev, scale=-1.0) + o_acc[...] 
= nl.multiply(o_acc, alpha) + + # exp(qk - max) and row sum + p = nl.ndarray((par_dim(B_P), kv_tile_size), dtype=nl.bfloat16) + p_sum = nl.ndarray((par_dim(B_P), 1), dtype=np.float32) + p[:, :] = nisa.activation_reduce( + np.exp, + qk_sbuf, + bias=-1 * m_cur, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_sum[:, 0], + dtype=nl.bfloat16, + ) + + # Load V tile: (kv_tile_size // B_P, par_dim(B_P), d) + n_v_sub = kv_tile_size // B_P + v_tile = nl.ndarray((n_v_sub, par_dim(B_P), d), dtype=nl.bfloat16) + for vi in nl.affine_range(n_v_sub): + v_tile[vi, :, :] = nl.load( + v[k_batch, nl.ds(kv_start + vi * B_P, B_P), :], + dtype=nl.bfloat16, + ) + + # Transpose p for PV matmul: need p as (par_dim, kv_tile_size) + # in the right layout for contraction + p_t = nl.ndarray((par_dim(B_P), kv_tile_size), dtype=nl.bfloat16) + for ti in nl.affine_range(kv_tile_size // B_P): + p_t_psum = nl.ndarray( + (par_dim(B_P), B_P), + dtype=np.float32, + buffer=nl.psum, + ) + p_t_psum[:, :] = nisa.nc_transpose(p[:, nl.ds(ti * B_P, B_P)]) + p_t[:, nl.ds(ti * B_P, B_P)] = nl.copy( + p_t_psum, dtype=nl.bfloat16 + ) + + # PV matmul: (B_P, kv_tile_size) @ (kv_tile_size, d) -> (B_P, d) + # d is on the free axis of PSUM (max 512) + pv = nl.zeros( + (par_dim(B_P), d), + dtype=np.float32, + buffer=nl.psum, + lazy_initialization=True, + ) + for vi in nl.affine_range(n_v_sub): + pv[:, :] += nl.matmul( + p_t[:, nl.ds(vi * B_P, B_P)], + v_tile[vi, :, :], + transpose_x=True, + ) + + o_acc[:, :] = nl.add(o_acc, pv) + + # Update log-sum-exp + exp_l = nisa.activation(nl.exp, m_cur, bias=l_acc[:, 0], scale=-1.0) + l_acc[:, 0] = nl.add( + m_cur, nisa.activation(nl.log, exp_l, bias=p_sum[:, 0]) + ) + + # Final rescale + final_exp = nisa.activation( + np.exp, l_acc[:, 0], bias=m_acc[:, 0], scale=-1.0 + ) + out = nl.multiply(o_acc, final_exp, dtype=nl.bfloat16) + + # Store output: (B_P, d) -> o[batch, d, seqlen] (transposed) + # We need to write to (bs, d, seqlen_q) layout + # out is (par_dim(B_P), d), we need (d, 
B_P) in HBM + # Transpose: (B_P, d) -> (d, B_P) via nc_transpose in chunks + # d can be up to 512, B_P is 128 + # nc_transpose takes (par_dim(P), F) -> (par_dim(F), P) but F <= 512 + # We need d on par_dim for the output, but d can be > 128 + # Instead, store in d-chunks: each chunk is (B_P, D_TILE) + # Transpose each to (D_TILE, B_P) and store + for dc in nl.affine_range(num_d_chunks): + out_chunk = out[:, nl.ds(dc * D_TILE, D_TILE)] + out_t_psum = nl.ndarray( + (par_dim(D_TILE), B_P), dtype=np.float32, buffer=nl.psum + ) + out_t_psum[:, :] = nisa.nc_transpose(out_chunk) + out_t = nl.ndarray((par_dim(D_TILE), B_P), dtype=nl.bfloat16) + out_t[:, :] = nl.copy(out_t_psum, dtype=nl.bfloat16) + nl.store( + o[q_batch, nl.ds(dc * D_TILE, D_TILE), nl.ds(qi * B_P, B_P)], + out_t, + ) + + return o diff --git a/contrib/models/gemma-4-26b-a4b-it/test/__init__.py b/contrib/models/gemma-4-26b-a4b-it/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/gemma-4-26b-a4b-it/test/integration/__init__.py b/contrib/models/gemma-4-26b-a4b-it/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/gemma-4-26b-a4b-it/test/integration/test_model.py b/contrib/models/gemma-4-26b-a4b-it/test/integration/test_model.py new file mode 100644 index 00000000..56349e52 --- /dev/null +++ b/contrib/models/gemma-4-26b-a4b-it/test/integration/test_model.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Integration smoke for Gemma-4-26B-A4B-it NeuronX Distributed Inference port. + +This is a Stage-1 / Stage-2 / Stage-3 smoke runner mirrored on PR #106's +test layout. It is invoked via the helper scripts under ``scripts/`` — +the body here is the same flow as those scripts but importable. 
+ +Usage: + # Stage 1 — dense path only (fast) + GEMMA4_DISABLE_MOE=1 PYTHONPATH=src \ + python test/integration/test_model.py compile + + # Stage 2 — MoE on + PYTHONPATH=src python test/integration/test_model.py compile + + # Stage 3 — generate + PYTHONPATH=src python test/integration/test_model.py generate +""" + +import json +import os +import sys +import time +from pathlib import Path + +import torch + +# Apply NxDI runtime patches (NKI kernel for d>128, get_last_kv_window fix). +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +import ndxi_patch # noqa: E402 + +ndxi_patch.apply_patch() + +from modeling_gemma4_neuron import ( # noqa: E402 + Gemma4InferenceConfig, + Gemma4NeuronConfig, + NeuronGemma4ForCausalLM, +) + + +MODEL_PATH = os.environ.get("GEMMA4_MODEL_PATH", "/home/ubuntu/gemma4-26b-a4b") +COMPILED_PATH = os.environ.get("GEMMA4_COMPILED_PATH", "/home/ubuntu/gemma4-compiled") +TP_DEGREE = int(os.environ.get("GEMMA4_TP_DEGREE", "8")) +BATCH_SIZE = int(os.environ.get("GEMMA4_BATCH_SIZE", "1")) +SEQ_LEN = int(os.environ.get("GEMMA4_SEQ_LEN", "256")) +MAX_NEW_TOKENS = int(os.environ.get("GEMMA4_MAX_NEW_TOKENS", "8")) +PROMPT = os.environ.get("GEMMA4_PROMPT", "Hello, my name is") +MOE_EP_DEGREE = int(os.environ.get("GEMMA4_MOE_EP_DEGREE", "1")) +MOE_TP_DEGREE = int(os.environ.get("GEMMA4_MOE_TP_DEGREE", str(TP_DEGREE))) + + +def create_config(model_path: str) -> Gemma4InferenceConfig: + neuron_config = Gemma4NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + max_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + on_device_sampling_config=None, + torch_dtype=torch.bfloat16, + fused_qkv=False, + attn_kernel_enabled=False, + moe_ep_degree=MOE_EP_DEGREE, + moe_tp_degree=MOE_TP_DEGREE, + glu_mlp=True, + glu_type="glu", + router_act_fn="softmax", + router_dtype="float32", + disable_normalize_top_k_affinities=True, + ) + + def load_config_fn(config_obj): + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: 
+ config_dict = json.load(f) + for k, v in config_dict.items(): + setattr(config_obj, k, v) + + cfg = Gemma4InferenceConfig( + neuron_config=neuron_config, load_config=load_config_fn + ) + if os.environ.get("GEMMA4_DISABLE_MOE", "0") == "1": + cfg.disable_moe_for_smoke_compile = True + return cfg + + +def cmd_compile() -> int: + if not Path(MODEL_PATH).exists(): + print(f"ERROR: model path {MODEL_PATH} does not exist", file=sys.stderr) + return 1 + config = create_config(MODEL_PATH) + print( + f"Config: hidden_size={config.hidden_size}, " + f"num_layers={config.num_hidden_layers}, " + f"num_experts={getattr(config, 'num_experts', None)}, " + f"top_k={getattr(config, 'top_k_experts', None)}" + ) + t0 = time.perf_counter() + model = NeuronGemma4ForCausalLM(MODEL_PATH, config) + model.compile(COMPILED_PATH) + print(f"Compile finished in {(time.perf_counter() - t0)/60:.1f} min") + model = NeuronGemma4ForCausalLM(MODEL_PATH, config) + model.load(COMPILED_PATH) + print("Smoke compile + load OK") + return 0 + + +def cmd_generate() -> int: + # Defer to scripts/smoke_inference.py — this entrypoint is just a + # pytest-friendly thin wrapper. The smoke_inference.py script in + # scripts/ handles tokenizer fallback for gemma-4 special-tokens. 
+ from subprocess import run + + here = Path(__file__).resolve().parent.parent.parent + smoke = here / "scripts" / "smoke_inference.py" + return run([sys.executable, str(smoke)]).returncode + + +def main(argv) -> int: + if len(argv) < 2 or argv[1] not in {"compile", "generate"}: + print("usage: test_model.py {compile|generate}", file=sys.stderr) + return 2 + return cmd_compile() if argv[1] == "compile" else cmd_generate() + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/contrib/models/gemma-4-26b-a4b-it/test/unit/__init__.py b/contrib/models/gemma-4-26b-a4b-it/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From 905bf86edb889509b9955f0bb436240ebe092c69 Mon Sep 17 00:00:00 2001 From: xniwang Date: Wed, 3 Jun 2026 07:07:52 +0000 Subject: [PATCH 2/2] README: add Stage 5 canonical Gemma-4 validation results 11/12 first-16-token match at 100% vs HF CPU bf16 (Gemma4ForConditionalGeneration, transformers 5.10.0.dev0) using processor.apply_chat_template (enable_thinking=False and =True) and processor.parse_response. The lone 87.5% (capital/thinking-on/sample) is sampling-RNG divergence; greedy at the same setting is 100%. enable_thinking=True exercises the multi-channel response path through MoE routing -- backends produce identical greedy tokens and identical parse_response {role,thinking,content} dicts. 
---
 contrib/models/gemma-4-26b-a4b-it/README.md | 47 +++++++++++++++++++--
 1 file changed, 44 insertions(+), 3 deletions(-)

diff --git a/contrib/models/gemma-4-26b-a4b-it/README.md b/contrib/models/gemma-4-26b-a4b-it/README.md
index 5218c871..ee14f3c1 100644
--- a/contrib/models/gemma-4-26b-a4b-it/README.md
+++ b/contrib/models/gemma-4-26b-a4b-it/README.md
@@ -82,6 +82,46 @@ MoE compile requires `--internal-hlo2tensorizer-options='--verify-hlo=false'`
 | Throughput | 114 tok/s |
 | Status | **PASS** (coherence smoke; greedy + base-style continuation, no chat template) |
 
+### Stage 5 — Canonical Gemma-4 chat validation (added 2026-06-03)
+
+Replaces the Stage 3 "smoke only" caveat with full canonical validation
+following the official HF Gemma-4 pattern: `processor.apply_chat_template`
+(both `enable_thinking=False` and `=True`) plus `processor.parse_response`.
+
+Compared the Trainium 2 port head-to-head against `Gemma4ForConditionalGeneration`
+on CPU bf16 (transformers 5.10.0.dev0). 3 prompts × {greedy, sampled} × {thinking off, on}.
+
+| prompt | thinking | greedy | sampled |
+|---|---|---|---|
+| `Write a short joke about saving RAM.` | off | **16/16 (100%)** | 16/16 (100%) |
+| `Write a short joke about saving RAM.` | on | **16/16 (100%)** | 16/16 (100%) |
+| `What is the capital of France?` | off | **9/9 (100%)** EOS | 9/9 (100%) EOS |
+| `What is the capital of France?` | on | **16/16 (100%)** | 14/16 (87.5%) |
+| `Explain quantum entanglement in two sentences.` | off | **16/16 (100%)** | 16/16 (100%) |
+| `Explain quantum entanglement in two sentences.` | on | **16/16 (100%)** | 16/16 (100%) |
+
+Match = first 16 tokens equal to the HF CPU bf16 reference. **11/12 cases at
+100%; the lone 87.5% is sampling RNG divergence (the two backends run different
+frameworks under the same seed); greedy at the same setup is 100%.**
+
+`enable_thinking=True` exercises the full multi-channel response path
+(`<|channel>thought\n...`) through the MoE router. Both backends
+emit identical tokens and `parse_response` returns the same
+`{role, thinking, content}` dict, e.g.:
+
+```python
+{'role': 'assistant',
+ 'thinking': 'The user is asking for the capital of France.\n'
+             'The capital of France is Paris.\nState the answer clearly.',
+ 'content': 'The capital of France is **Paris**.'}
+```
+
+Latency on Trainium 2 (TP=8, BF16, seq=256): TTFT ~303 ms, TPOT ~8.3 ms
+(~120 tok/s greedy). See
+[`agent_artifacts/round4/STAGE5_CANONICAL_VALIDATION.md`](https://github.com/xniwangaws/NeuronStuff/blob/main/gemma4-port-26b-a4b/agent_artifacts/round4/STAGE5_CANONICAL_VALIDATION.md)
+on the upstream reference repo for full output, raw JSON, and
+comparator script.
+
 ## What was reused from existing NxDI
 
 - `NeuronAttentionBase` — Q/K/V/o projections, KV cache, GQA sharding, mask
@@ -124,9 +164,10 @@ MoE compile requires `--internal-hlo2tensorizer-options='--verify-hlo=false'`
 
 ## Open issues / known limitations
 
-- **Smoke-test only**: greedy + no chat template ⇒ Stage 3 output repeats.
-  Token-match accuracy vs HF reference still pending (sampling +
-  chat-formatted prompts).
+- **Validated as of Stage 5** (2026-06-03): canonical chat (`apply_chat_template` + `parse_response`,
+  including `enable_thinking={False,True}`) matches HF CPU bf16 at 100% token agreement
+  for 11/12 greedy/sample combos (the 12th is sampling RNG divergence; greedy is 100%).
+  See Stage 5 above.
 - **AutoTokenizer fix**: HF transformers ≤ 4.45 trips on gemma-4's
   special-tokens list-vs-dict shape; `scripts/smoke_inference.py` falls
   back to the raw `tokenizers` Rust backend.
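
Reviewer note on the Stage 5 metric: the "first-16-token match" figures in the README hunk can be read as a position-wise comparison of generated token ids over a fixed window. The sketch below is a guess at that computation, not the comparator script referenced in the README. The function name `first_n_match`, the position-wise counting rule, and the choice to shorten the window when the reference stops early (as in the `9/9 ... EOS` rows) are all assumptions made for illustration.

```python
from typing import Sequence, Tuple


def first_n_match(
    ref_ids: Sequence[int], test_ids: Sequence[int], n: int = 16
) -> Tuple[int, int]:
    """Return (matches, compared) over the first n generated token ids.

    Hypothetical helper: `compared` is the length of the reference window,
    so a reference that ends early (EOS) shortens the denominator.
    """
    ref = list(ref_ids)[:n]
    test = list(test_ids)[:n]
    # Position-wise equality; positions missing from `test` count as misses
    # because zip stops at the shorter sequence.
    matches = sum(1 for r, t in zip(ref, test) if r == t)
    return matches, len(ref)


if __name__ == "__main__":
    # Made-up token ids, chosen only to reproduce a 14/16 style result.
    ref = list(range(100, 116))
    test = list(ref)
    test[5] = -1
    test[11] = -1
    m, c = first_n_match(ref, test)
    print(f"{m}/{c} ({100.0 * m / c:.1f}%)")  # prints "14/16 (87.5%)"
```

A prefix-match rule (stop counting at the first mismatch) would give different numbers for sampled runs, so the README may want to state which rule the comparator uses.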