diff --git a/examples/models/gpt_oss/README.md b/examples/models/gpt_oss/README.md new file mode 100644 index 0000000..c00697d --- /dev/null +++ b/examples/models/gpt_oss/README.md @@ -0,0 +1,76 @@ +# gpt-oss on Trainium + +A clean implementation of OpenAI's gpt-oss MoE models (e.g., `gpt-oss-20b`) for +AWS Trainium, built on NKIPy. + +## Setup + +``` sh +cd nkipy +uv sync --all-groups +source .venv/bin/activate +cd examples/models/gpt_oss +``` + +## Quickstart + +`test.sh` handles weight preparation and runs a generation end-to-end: + +``` sh +./test.sh +``` + +Or run generation directly (assumes weights are already prepared): + +``` sh +WEIGHTS=./tmp_gpt-oss-20b +TP=4 + +torchrun --nproc-per-node $TP gpt_oss.py \ + -n 500 --checkpoint $WEIGHTS --model openai/gpt-oss-20b \ + "The capital of France is" +``` + +You can point `--model` at a local checkpoint directory too. + +## Weight preparation + +gpt-oss ships its experts **MXFP4-quantized** (`*_blocks` / `*_scales`). The prep +step dequantizes them to bf16 so the NKI kernels run purely in bf16, and it +shards every tensor for tensor parallelism: + +``` sh +python tensor_preparation.py \ + --model-name openai/gpt-oss-20b \ + --world-size 4 \ + --output-dir ./tmp_gpt-oss-20b +``` + +This writes `shard_{rank}.safetensors` files. Dequantized bf16 weights are +larger than the packed checkpoint (~40 GB total), so make sure you have disk +headroom. 
+ +## Architecture notes + +gpt-oss differs from the Qwen3 MoE example in several ways, all handled here: + +| Feature | Handling | +|---|---| +| MXFP4 experts | Dequantized to bf16 at prep time (`tensor_preparation.py`) | +| Interleaved gate/up | De-interleaved to `[gate \| up]` at prep time | +| Clamped SwiGLU | `(up+1) * gate*sigmoid(alpha*gate)` with `clamp(limit=7)` (`kernels/feedforward.py`) | +| Attention sinks | Per-head sink logit concatenated into softmax, then dropped (`kernels/attention.py`) | +| QKV / output bias | Carried through prep and added in the attention kernel | +| Sliding-window attention | Alternating sliding (window=128) / full causal layers; one kernel compiled per attention type | +| YaRN RoPE | `inv_freq` + attention-scaling precomputed from HF config (`config.py`) and baked into the cos/sin cache | +| Router | top-k on raw logits (+bias), then softmax over the selected logits | + +## Files + +| File | Purpose | +|---|---| +| `gpt_oss.py` | Model definition (`GptOssModel`) and text generation | +| `config.py` | Model configuration (incl. 
@dataclass
class Config:
    """Hyper-parameters consumed by the gpt-oss model and its kernels.

    Fields without defaults must be supplied by the caller; ``get_config``
    fills them from the Hugging Face configuration of the checkpoint.
    """

    hidden_size: int
    num_heads: int
    head_dim: int
    num_kv_heads: int
    num_layers: int
    num_experts_per_tok: int
    num_experts: int
    # YaRN RoPE constants (inverse frequencies and the post-scaling factor),
    # computed once from the HF config rather than at runtime.
    rope_inv_freq: np.ndarray
    rope_attention_scaling: float
    # One entry per decoder layer: "sliding_attention" or "full_attention".
    layer_types: list
    sliding_window: int
    # gpt-oss clamped-SwiGLU parameters.
    swiglu_alpha: float = 1.702
    swiglu_limit: float = 7.0
    context_len: int = None
    max_new_tokens: int = None
    max_batch_size: int = 1
    norm_eps: float = 1e-5
    intermediate_size: int = 2880
    max_seq_len: int = 4096
    dtype: np.dtype = DTYPE
    additional_compiler_args_nkipy: str = "--lnc 1"
    # Indices of decoder layers whose outputs are captured for EAGLE-3
    # speculative decoding; None (the default) turns capture off.
    aux_layers: tuple = None

    def is_sliding(self, layer_id: int) -> bool:
        """Tell whether decoder layer ``layer_id`` uses sliding-window attention."""
        layer_kind = self.layer_types[layer_id]
        return layer_kind == "sliding_attention"

    @staticmethod
    def default_aux_layers(num_layers: int) -> tuple:
        """Return EAGLE-3's low/mid/high tap layers (vLLM convention)."""
        low, mid, high = 2, num_layers // 2, num_layers - 3
        return (low, mid, high)

    @staticmethod
    def peagle_aux_layers(num_layers: int) -> tuple:
        """Return P-EAGLE's tap layers (0-indexed, captured before each layer)."""
        first, mid, last = 0, num_layers // 2, num_layers - 1
        return (first, mid, last)


def get_config(model_name, context_len, max_new_tokens):
    """Build a ``Config`` from the Hugging Face config of ``model_name``.

    Args:
        model_name: Hub id or local directory handed to ``AutoConfig``.
        context_len: Stored unchanged in ``Config.context_len``.
        max_new_tokens: Stored unchanged in ``Config.max_new_tokens``.

    Returns:
        A populated ``Config``. ``intermediate_size`` is the HF value divided
        by the current ``torch.distributed`` world size.
    """
    hf = AutoConfig.from_pretrained(model_name)

    # The RoPE inverse frequencies and attention scaling do not depend on any
    # runtime tensor, so they are computed once here and later baked into the
    # kernel's cos/sin cache.
    rope_type = hf.rope_parameters["rope_type"]
    inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](hf, device=None)

    return Config(
        # Attention geometry.
        hidden_size=hf.hidden_size,
        num_heads=hf.num_attention_heads,
        num_kv_heads=hf.num_key_value_heads,
        head_dim=hf.head_dim,
        num_layers=hf.num_hidden_layers,
        layer_types=list(hf.layer_types),
        sliding_window=hf.sliding_window,
        norm_eps=hf.rms_norm_eps,
        # Mixture-of-experts feed-forward.
        intermediate_size=hf.intermediate_size // dist.get_world_size(),
        num_experts=hf.num_local_experts,
        num_experts_per_tok=hf.num_experts_per_tok,
        swiglu_alpha=getattr(hf, "swiglu_alpha", 1.702),
        swiglu_limit=hf.swiglu_limit,
        # Rotary embedding constants.
        rope_inv_freq=np.asarray(inv_freq, dtype=np.float32),
        rope_attention_scaling=float(attention_scaling),
        # Generation limits supplied by the caller.
        context_len=context_len,
        max_new_tokens=max_new_tokens,
    )
P-EAGLE Speculative Decoding for gpt-oss on Trainium + +Parallel-drafting speculative decoding using [P-EAGLE](https://arxiv.org/abs/2602.01469) +for the gpt-oss model family on AWS Trainium. Generates K draft tokens in a +**single forward pass** (not K sequential passes), then verifies them against the +target in one multi-token target forward. + +## Setup + +``` sh +cd nkipy +uv sync --all-groups +source .venv/bin/activate +cd examples/models/gpt_oss +``` + +## Quickstart + +### 1. Prepare weights + +The target model (gpt-oss-20b) must already be prepared (see `../README.md`): + +``` sh +# Target (if not already done) +python tensor_preparation.py \ + --model-name /path/to/gpt-oss-20b \ + --world-size 4 --output-dir ./tmp_gpt-oss-20b + +# Drafter (P-EAGLE, replicated on every rank — small, ~3.6 GB) +python eagle/tensor_preparation.py \ + --model-name /path/to/GPT-OSS-20B-P-EAGLE \ + --output-dir ./eagle/tmp_p-eagle +``` + +### 2. Run speculative decoding + +``` sh +TP=4 +torchrun --nproc-per-node $TP eagle/speculate.py \ + --target-checkpoint ./tmp_gpt-oss-20b \ + --draft-checkpoint ./eagle/tmp_p-eagle \ + --model /path/to/gpt-oss-20b \ + --draft-model /path/to/GPT-OSS-20B-P-EAGLE \ + -n 256 -k 7 \ + "Write a Python function that implements binary search." +``` + +Output includes acceptance metrics: + +``` +Time to first token: 0.6s +Generated 256 tokens in N verify steps +Mean acceptance length: X.XX (K=7) +Decode tokens/sec: XX.XX +``` + +## How it works + +### Speculation loop + +``` +1. Target prefill on prompt → first token + 3 tapped hidden states +2. Seed: run first token through target decode → hidden states for drafter +3. Loop: + a. Drafter: K tokens in ONE parallel forward pass + b. Target verify: run [last_accepted, draft_0, ..., draft_{K-1}] through + target layers (seq_len = K+1) with block-causal mask + c. Accept: longest prefix where draft[i] == target_argmax[i] + d. Emit accepted tokens + bonus correction token + e. 
Advance KV cache position by (accepted + 1) +``` + +### P-EAGLE parallel drafting (K tokens in one pass) + +Unlike autoregressive EAGLE which runs K sequential drafter passes, P-EAGLE +generates all K draft tokens simultaneously: + +| Position | Embedding input | Hidden state input | +|----------|----------------|-------------------| +| 0 (NTP) | `embed(last_accepted_token)` | `fc(concat(aux_layer_2, aux_layer_12, aux_layer_21))` — real target hidden | +| 1..K-1 (MTP) | `embed(ptd_token_id)` — placeholder | `fc(mask_hidden)` — learnable shared hidden | + +All K positions attend under a **cross-depth causal mask** (depth d sees depths +≤ d) through the EAGLE-3 fusion midlayer + 3 plain Llama decoder layers. Each +position's `lm_head` logit gives one draft token. + +### Architecture details + +The P-EAGLE drafter (`GPT-OSS-20B-P-EAGLE`, ~3.6 GB bf16): + +| Component | Description | +|-----------|-------------| +| `fc` (8640→2880) | Fuses 3 target hidden states (layers 2, 12, 21 of 24-layer target) | +| `midlayer` | EAGLE-3 fusion decoder layer: attention takes 2×hidden (embed⊕hidden), has `hidden_norm` | +| `layers.1/2/3` | Plain Llama decoder layers (SiLU MLP, llama3 RoPE) | +| `mask_hidden` (1,1,8640) | Learnable shared hidden state for MTP positions | +| `ptd_token_id` = 201020 | Placeholder token whose embedding fills MTP positions | +| `d2t` / `t2d` | Draft↔target vocab mapping (identity for this checkpoint) | +| `lm_head` (2880→201088) | Full target vocab, replicated on every rank | + +### Verification + +The target verifies K+1 candidate tokens in a single multi-token forward pass: +- Runs the full gpt-oss decoder stack with `seq_len = K+1` at a runtime offset +- Uses absolute-position RoPE and a block-causal attention mask +- Writes K+1 new KV cache entries contiguously +- Produces per-position greedy argmax via cross-rank reduction + +**Greedy acceptance makes KV rollback implicit**: rejected speculative entries are +overwritten by the next verify pass, and 
the causal mask prevents any query from +attending past its own position. + +## Files + +| File | Purpose | +|------|---------| +| `speculate.py` | Main entry: speculation loop orchestrating target + drafter | +| `config.py` | `EagleConfig` for the P-EAGLE drafter (llama3 RoPE, fc, mask_hidden, K) | +| `tensor_preparation.py` | Convert P-EAGLE checkpoint to x@W form (replicated, no TP) | +| `drafter_model.py` | Device-side drafter: loads weights, compiles kernel, runs draft | +| `kernels/drafter.py` | Parallel-drafting forward kernel (K tokens in one pass) | +| `kernels/drafter_layer.py` | EAGLE-3 fusion midlayer + plain Llama layers | +| `kernels/verify.py` | Multi-position greedy argmax for verification | +| `kernels/rope.py` | llama3 RoPE (different from target's YaRN RoPE) | +| `kernels/rmsnorm.py`, `softmax.py` | Leaf kernels (copied from base) | + +## Validation + +| What | Result | +|------|--------| +| Drafter kernel math | ✅ All 7 draft tokens match independent PyTorch reference | +| Speculation output | ✅ Lossless — output matches HF greedy baseline exactly | +| Drafter with HF hidden states | ✅ Draft[0] matches target greedy perfectly | +| Multi-token verify | ✅ Block-causal mask + KV scatter correct | + +## Known limitation: acceptance length + +The current CPU-side acceptance length is ~1.4 tokens/step (vs paper's ~3.7 for +GPT-OSS 20B at K=5). The NTP (depth 0) position works correctly — draft[0] +frequently matches the target's greedy. The MTP (depth 1+) positions +underperform, producing generic tokens instead of context-specific ones. 
+ +### What's been verified + +- Drafter NTP produces the correct next token when given HF hidden states ✅ +- Drafter KV cache is necessary and improves acceptance from 1.0 to 1.4 ✅ +- The EAGLE shifted-token convention (input_ids shifted +1 vs hidden states) + matches vLLM's implementation ✅ +- Hidden-state capture point (output of tap layers) matches what vLLM uses ✅ +- The midlayer concat `[norm(embed), norm(hidden)]` → 2H → attention → H output + with H-wide residual matches vLLM's `Eagle3DecoderLayer` ✅ + +### vLLM reference (parallel_drafting) + +Studied from the installed vLLM at `private-vllm-neuron/.venv`. Key findings: + +1. vLLM's parallel drafting produces ALL K tokens in **one forward pass**: the + expanded input contains [shifted context tokens | bonus (next_token) | K-1 + ptd_token positions]. All go through the model together with PagedAttention. + +2. The `parallel_drafting_hidden_state_tensor` = `fc(mask_hidden)` (the fc-fused + mask_hidden at 2880 dim), placed at the MTP positions in the hidden_states + input to the model. + +3. The Triton kernel `copy_and_expand_eagle_inputs_kernel` handles the layout: + positions are sequential (start_pos + j), parallel-draft slots get + `ptd_token_id` for input_ids and `parallel_drafting_hidden_state_tensor` for + hidden_states. + +4. The model's `forward(input_ids, positions, hidden_states)` takes all three + as separate tensors of the same length. Only the midlayer (layer 0) + concatenates embeds with hidden to produce 2H; subsequent layers are standard. + +### Remaining gap to investigate + +The MTP positions have correct architecture but produce generic predictions. The +most likely remaining issue is an off-by-one in how the drafter's RoPE positions +map to the target's absolute positions during the KV-cached speculation loop. +vLLM assigns positions sequentially from the start of the context, and the +parallel-draft positions get positions immediately following the last valid token. 
+Our `drafter_cpu.py` does the same via `torch.arange(cache_len, cache_len + K)`. + +### Path forward + +1. Run the drafter via vLLM on GPU with this exact checkpoint and capture the + actual acceptance length (confirms the checkpoint quality ceiling) +2. If vLLM achieves ~3.7, the issue is in our inference loop (position/hidden + alignment during rollback+extend) +3. If vLLM also gets ~1.4, the checkpoint may be under-trained for this prompt + distribution (the paper evaluates on HumanEval/MT-Bench/GSM-8K specifically) diff --git a/examples/models/gpt_oss/eagle/__init__.py b/examples/models/gpt_oss/eagle/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/models/gpt_oss/eagle/config.py b/examples/models/gpt_oss/eagle/config.py new file mode 100644 index 0000000..e1ab4a6 --- /dev/null +++ b/examples/models/gpt_oss/eagle/config.py @@ -0,0 +1,84 @@ +"""Configuration for the P-EAGLE parallel-drafting drafter. + +The drafter is a small Llama-style model trained for a specific gpt-oss target. +It generates K draft tokens in a single forward pass (see ``eagle/README.md`` and +arXiv 2602.01469). Its structure, from the checkpoint: + + * ``midlayer`` - the EAGLE-3 fusion decoder layer (layer 0). Its attention + projections take 2*hidden (embedding concat hidden), and it owns the extra + ``hidden_norm``. + * ``layers.1 .. layers.{N-1}`` - plain Llama decoder layers. + * ``fc`` - fuses the 3 concatenated target hidden states (3*hidden) -> hidden. + * ``mask_hidden`` - learnable shared hidden state for MTP (depth>0) positions. + * ``ptd_token_id`` - placeholder token id whose embedding substitutes the + unknown previous token at MTP positions. + * ``d2t`` / ``t2d`` - draft<->target vocab maps (identity for this checkpoint). 
@dataclass
class EagleConfig:
    """Hyper-parameters of the P-EAGLE drafter, as consumed by its kernels.

    Fields without defaults are filled by ``get_eagle_config`` from the
    drafter's Hugging Face configuration.
    """

    hidden_size: int
    num_heads: int
    head_dim: int
    num_kv_heads: int
    num_layers: int  # total drafter decoder layers (incl. the fusion midlayer)
    intermediate_size: int
    # llama3 RoPE.
    rope_inv_freq: np.ndarray
    rope_attention_scaling: float
    # Target-model hidden size whose 3 tapped layers feed `fc` (3*target_hidden).
    target_hidden_size: int
    # Draft / target vocab sizes (equal for this checkpoint).
    draft_vocab_size: int
    target_vocab_size: int
    # Placeholder token id used at MTP (depth>0) positions.
    ptd_token_id: int
    # Number of draft tokens produced per parallel forward pass.
    num_draft_tokens: int = 7
    norm_eps: float = 1e-5
    max_seq_len: int = 4096
    max_batch_size: int = 1
    dtype: np.dtype = DTYPE
    additional_compiler_args_nkipy: str = "--lnc 1"


def get_eagle_config(
    draft_model_name,
    target_hidden_size,
    num_draft_tokens=7,
    max_seq_len=4096,
):
    """Build an ``EagleConfig`` from the drafter's Hugging Face config.

    Args:
        draft_model_name: Hub id or local directory handed to ``AutoConfig``.
        target_hidden_size: Hidden size of the target model; stored unchanged.
        num_draft_tokens: Draft tokens produced per parallel forward pass.
        max_seq_len: Stored unchanged in ``EagleConfig.max_seq_len``.

    Returns:
        A populated ``EagleConfig``.

    Raises:
        AttributeError: If the HF config lacks a required attribute such as
            ``ptd_token_id`` or ``draft_vocab_size``.
    """
    hf = AutoConfig.from_pretrained(draft_model_name)

    # llama3 RoPE: precompute inverse frequencies + attention scaling once.
    rope_init_fn = ROPE_INIT_FUNCTIONS[hf.rope_scaling["rope_type"]]
    inv_freq, attention_scaling = rope_init_fn(hf, device=None)

    return EagleConfig(
        hidden_size=hf.hidden_size,
        num_heads=hf.num_attention_heads,
        head_dim=hf.head_dim,
        num_kv_heads=hf.num_key_value_heads,
        num_layers=hf.num_hidden_layers,
        # NOTE(review): the drafter is described as replicated on every rank
        # (no tensor parallelism), yet the MLP width is divided by the world
        # size here - confirm the kernels really expect a per-rank width.
        intermediate_size=hf.intermediate_size // dist.get_world_size(),
        rope_inv_freq=np.asarray(inv_freq, dtype=np.float32),
        rope_attention_scaling=float(attention_scaling),
        target_hidden_size=target_hidden_size,
        draft_vocab_size=hf.draft_vocab_size,
        target_vocab_size=hf.vocab_size,
        # Plain attribute access: getattr() with a constant name and no
        # default adds nothing and hides the attribute from static tooling.
        ptd_token_id=hf.ptd_token_id,
        num_draft_tokens=num_draft_tokens,
        norm_eps=hf.rms_norm_eps,
        max_seq_len=max_seq_len,
    )
class DrafterCPU:
    """CPU implementation of the P-EAGLE parallel drafter with a KV cache.

    The cache holds one ``(k, v)`` pair per drafter layer, each laid out as
    ``(B, n_kv, seq, head_dim)``. The index along the ``seq`` axis is treated
    as the absolute token position when the causal mask is built, so the
    cache length and ``cache_len`` must stay in step.
    """

    def __init__(self, model_path, target_hidden_size, num_draft_tokens=7):
        """Load the drafter config, weights and RoPE constants.

        Args:
            model_path: Directory holding the HF config and ``model.safetensors``.
            target_hidden_size: Hidden size of the target model (stored only).
            num_draft_tokens: Number of positions K drafted per call.
        """
        self.config = AutoConfig.from_pretrained(model_path)
        self.H = self.config.hidden_size
        self.K = num_draft_tokens
        self.target_hidden_size = target_hidden_size
        self.eps = self.config.rms_norm_eps
        self.n_heads = self.config.num_attention_heads
        self.n_kv = self.config.num_key_value_heads
        self.head_dim = self.config.head_dim
        self.n_layers = self.config.num_hidden_layers  # 4 (midlayer + 3 plain)
        self.ptd_token_id = self.config.ptd_token_id

        # Load weights. Only floating-point tensors are cast to bf16: integer
        # tensors such as the d2t/t2d vocab maps would lose precision in bf16
        # (8 bits of mantissa) and yield wrong ids for non-zero offsets.
        self.w = {}
        with safe_open(f"{model_path}/model.safetensors", framework="pt") as f:
            for k in f.keys():
                t = f.get_tensor(k)
                self.w[k] = t.to(torch.bfloat16) if t.is_floating_point() else t

        # Precompute RoPE.
        fn = ROPE_INIT_FUNCTIONS[self.config.rope_scaling["rope_type"]]
        inv_freq, self.rope_scaling = fn(self.config, None)
        self.inv_freq = inv_freq.float()

        # KV caches: list of (k, v) per layer, each (B, n_kv, seq, head_dim).
        # Stays None until reset() (called by prefill()) allocates the list.
        self.kv_caches = None
        self.cache_len = 0

    def reset(self):
        """Drop all cached keys/values and restart at position 0."""
        self.kv_caches = [None] * self.n_layers
        self.cache_len = 0

    def rollback(self, new_len):
        """Truncate KV caches to new_len (discard rejected speculative entries).

        Args:
            new_len: Number of leading sequence positions to keep. Expected to
                be no larger than the current cache length.
        """
        for i in range(self.n_layers):
            if self.kv_caches[i] is not None:
                k, v = self.kv_caches[i]
                # The sequence axis is dim 2 (see _attention, which appends
                # with torch.cat(..., dim=2)); dim 1 is the KV-head axis and
                # must not be sliced.
                self.kv_caches[i] = (k[:, :, :new_len], v[:, :, :new_len])
        self.cache_len = new_len

    # ── Building blocks ──────────────────────────────────────────────────────

    def _rms(self, x, w):
        """RMS-normalise ``x`` over its last axis in fp32, scale by ``w``, return bf16."""
        x = x.float()
        return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * w.float()).to(torch.bfloat16)

    def _rope_cos_sin(self, positions):
        """Return bf16 (cos, sin) tables for the given positions.

        positions: 1-D int tensor of absolute positions.
        """
        freqs = torch.outer(positions.float(), self.inv_freq)
        cos = (freqs.cos() * self.rope_scaling).to(torch.bfloat16)
        sin = (freqs.sin() * self.rope_scaling).to(torch.bfloat16)
        return cos, sin  # (S, head_dim/2)

    def _apply_rope(self, x, cos, sin):
        """Rotate the two halves of the last axis. x: (B, H, S, D); cos/sin: (S, D/2)."""
        h = x.shape[-1] // 2
        cos = cos[None, None, :, :]  # (1,1,S,D/2)
        sin = sin[None, None, :, :]
        x0, x1 = x[..., :h], x[..., h:]
        return torch.cat([x0 * cos - x1 * sin, x1 * cos + x0 * sin], dim=-1)

    def _attention(self, layer_idx, q_proj, k_proj, v_proj, o_proj, x, positions):
        """Self-attention with KV cache.

        Appends the keys/values of the ``S`` new positions to the cache of
        ``layer_idx`` and attends from the new queries to the whole cache.

        Args:
            layer_idx: Index into ``self.kv_caches``.
            q_proj, k_proj, v_proj, o_proj: Projection matrices in ``x @ W`` form.
            x: (B, S, in_features) layer input.
            positions: 1-D tensor with the absolute position of each new token.

        Returns:
            (B, S, out_features) attention output after ``o_proj``.
        """
        B, S, _ = x.shape
        nh, nkv, hd = self.n_heads, self.n_kv, self.head_dim
        rep = nh // nkv

        q = (x @ q_proj).view(B, S, nh, hd).transpose(1, 2)  # (B, nh, S, hd)
        k = (x @ k_proj).view(B, S, nkv, hd).transpose(1, 2)  # (B, nkv, S, hd)
        v = (x @ v_proj).view(B, S, nkv, hd).transpose(1, 2)

        # RoPE on the NEW positions only.
        cos, sin = self._rope_cos_sin(positions)
        q = self._apply_rope(q, cos, sin)
        k = self._apply_rope(k, cos, sin)

        # Update KV cache (append along the sequence axis, dim 2).
        if self.kv_caches[layer_idx] is None:
            self.kv_caches[layer_idx] = (k, v)
        else:
            pk, pv = self.kv_caches[layer_idx]
            self.kv_caches[layer_idx] = (torch.cat([pk, k], dim=2), torch.cat([pv, v], dim=2))

        # Full keys/values (cached + new), expanded to one KV head per query head.
        full_k, full_v = self.kv_caches[layer_idx]  # (B, nkv, total_len, hd)
        full_k = full_k.repeat_interleave(rep, dim=1)
        full_v = full_v.repeat_interleave(rep, dim=1)

        # Attention scores: q attends to full KV.
        total_len = full_k.shape[2]
        scores = (q @ full_k.transpose(2, 3)) / (hd ** 0.5)

        # Causal mask: position i (absolute) can attend to positions <= i.
        # Query positions are `positions`; key positions are 0..total_len-1,
        # i.e. the cache index is taken to be the absolute position.
        # mask[i, j] = 0 if key_pos[j] <= query_pos[i], else a large negative.
        key_pos = torch.arange(total_len, device=positions.device)
        mask = (key_pos[None, :] > positions[:, None]).float() * (-1e5)
        scores = scores + mask[None, None, :, :]  # broadcast over (B, nh)

        attn = F.softmax(scores.float(), dim=-1).to(torch.bfloat16)
        out = (attn @ full_v).transpose(1, 2).reshape(B, S, nh * hd)
        return out @ o_proj

    def _mlp(self, prefix, x):
        """SiLU-gated MLP using the weights stored under ``{prefix}.mlp.*``."""
        w = self.w
        gate = F.silu(x @ w[f"{prefix}.mlp.gate_proj.weight"].T)
        up = x @ w[f"{prefix}.mlp.up_proj.weight"].T
        return (gate * up) @ w[f"{prefix}.mlp.down_proj.weight"].T

    def _run_layers(self, x_2h, positions):
        """Run all drafter layers. x_2h: (B, S, 2H) concatenated [emb, hidden].

        Every layer appends ``S`` entries to its KV cache as a side effect.
        Returns the (B, S, H) output of the last layer (before the final norm).
        """
        w = self.w
        H = self.H

        # ── Fusion midlayer (layer 0) ──
        emb = x_2h[:, :, :H]
        hidden = x_2h[:, :, H:]
        # The residual stream is H wide and starts from the hidden half only.
        residual = hidden
        hn = self._rms(hidden, w["midlayer.hidden_norm.weight"])
        en = self._rms(emb, w["midlayer.input_layernorm.weight"])
        attn_in = torch.cat([en, hn], dim=-1)  # (B, S, 2H)

        attn_out = self._attention(
            0,
            w["midlayer.self_attn.q_proj.weight"].T,
            w["midlayer.self_attn.k_proj.weight"].T,
            w["midlayer.self_attn.v_proj.weight"].T,
            w["midlayer.self_attn.o_proj.weight"].T,
            attn_in,
            positions,
        )
        x = residual + attn_out
        x = x + self._mlp("midlayer", self._rms(x, w["midlayer.post_attention_layernorm.weight"]))

        # ── Plain layers 1..N-1 ──
        for i in range(1, self.n_layers):
            p = f"layers.{i}"
            residual = x
            xn = self._rms(x, w[f"{p}.input_layernorm.weight"])
            attn_out = self._attention(
                i,
                w[f"{p}.self_attn.q_proj.weight"].T,
                w[f"{p}.self_attn.k_proj.weight"].T,
                w[f"{p}.self_attn.v_proj.weight"].T,
                w[f"{p}.self_attn.o_proj.weight"].T,
                xn,
                positions,
            )
            x = residual + attn_out
            x = x + self._mlp(p, self._rms(x, w[f"{p}.post_attention_layernorm.weight"]))

        return x

    def _fc_fuse(self, hidden3):
        """Project 3*target_hidden → hidden via fc weight."""
        return hidden3 @ self.w["fc.weight"].T

    # ── Public API ───────────────────────────────────────────────────────────

    @torch.no_grad()
    def prefill(self, token_ids, aux_hidden_states):
        """Fill drafter KV cache with prompt context.

        Resets any existing cache first.

        Args:
            token_ids: (prompt_len,) int tensor of prompt tokens.
            aux_hidden_states: (1, prompt_len, 3*target_H) concatenated tap-layer
                hidden states from the target's prefill.
        """
        self.reset()
        S = len(token_ids)
        emb = self.w["embed_tokens.weight"][token_ids].unsqueeze(0)  # (1, S, H)
        hidden = self._fc_fuse(aux_hidden_states)  # (1, S, H)
        x_2h = torch.cat([emb, hidden], dim=-1)  # (1, S, 2H)
        positions = torch.arange(S)
        # Only the cache side effect matters here; the layer output is unused.
        self._run_layers(x_2h, positions)
        self.cache_len = S

    @torch.no_grad()
    def draft(self, target_aux3, last_token_id, accepted_tokens=None, accepted_aux=None):
        """Generate K draft tokens attending to the full cached context.

        The K drafted positions are appended to the KV cache and ``cache_len``
        is advanced by K; the caller is expected to call ``rollback`` to drop
        them before the next step.

        Args:
            target_aux3: (1, 1, 3*target_H) target hidden at the last accepted pos.
            last_token_id: int, the last accepted token.
            accepted_tokens: list[int], newly accepted tokens to add to cache first.
            accepted_aux: (1, n_accepted, 3*target_H) their target hidden states.

        Returns:
            list[int] of K draft token ids.
        """
        H = self.H

        # Step 1: Extend cache with newly accepted tokens (if any).
        if accepted_tokens is not None and len(accepted_tokens) > 0:
            A = len(accepted_tokens)
            acc_emb = self.w["embed_tokens.weight"][torch.tensor(accepted_tokens)].unsqueeze(0)
            acc_hidden = self._fc_fuse(accepted_aux)
            acc_2h = torch.cat([acc_emb, acc_hidden], dim=-1)
            acc_pos = torch.arange(self.cache_len, self.cache_len + A)
            self._run_layers(acc_2h, acc_pos)
            self.cache_len += A

        # Step 2: Build K positions (1 NTP + K-1 MTP).
        ntp_emb = self.w["embed_tokens.weight"][last_token_id].view(1, 1, H)
        ntp_hidden = self._fc_fuse(target_aux3)  # (1, 1, H)

        K = self.K
        if K > 1:
            # MTP positions share the placeholder-token embedding and the
            # fc-fused learnable mask_hidden.
            ptd_emb = self.w["embed_tokens.weight"][self.ptd_token_id].view(1, 1, H)
            mtp_embs = ptd_emb.expand(1, K - 1, H)
            mask_hidden = self._fc_fuse(
                self.w["mask_hidden"].view(1, 1, -1)
            ).expand(1, K - 1, H)
            embs = torch.cat([ntp_emb, mtp_embs], dim=1)
            hiddens = torch.cat([ntp_hidden, mask_hidden], dim=1)
        else:
            embs = ntp_emb
            hiddens = ntp_hidden

        x_2h = torch.cat([embs, hiddens], dim=-1)  # (1, K, 2H)
        draft_positions = torch.arange(self.cache_len, self.cache_len + K)

        # Run through layers (this extends the KV cache with K new entries).
        x = self._run_layers(x_2h, draft_positions)

        # Keep cache_len in step with the K speculative entries that were just
        # appended. They are temporary: rollback() discards them.
        self.cache_len += K

        # Logits → draft tokens.
        x = self._rms(x, self.w["norm.weight"])
        logits = (x @ self.w["lm_head.weight"].T).float()  # (1, K, vocab)
        draft_ids = logits[0].argmax(dim=-1).tolist()

        # Map draft vocab → target vocab: d2t holds per-id offsets (all zero,
        # i.e. identity, for this checkpoint).
        d2t = self.w["d2t"].long()
        draft_ids = [(d + int(d2t[d])) for d in draft_ids]

        return draft_ids
def draft(self, target_hidden3, last_token_id):
    """Run the compiled drafter kernel once and return the draft token ids.

    Args:
        target_hidden3: host torch tensor with the concatenated tapped target
            hidden states. It is cast to bfloat16 and uploaded as the
            kernel's ``target_hidden3`` input.
        last_token_id: id of the last accepted token, used to index
            ``self.embed_tokens``. Only batch row 0 of the kernel output is
            read below, so a batch size of 1 is assumed.

    Returns:
        list[int]: one id per draft depth, i.e. the argmax over the logits
        the kernel produced plus the ``d2t`` offset for that id.
    """
    hidden_size = self.config.hidden_size

    # Per-call inputs first, then the resident shared weights.
    kernel_inputs = {
        "target_hidden3": DeviceTensor.from_torch(
            target_hidden3.to(torch.bfloat16), "target3_in"
        ),
        "last_emb": DeviceTensor.from_torch(
            self.embed_tokens[last_token_id].reshape(1, 1, hidden_size),
            "last_emb_in",
        ),
        "ptd_emb": self.ptd_emb,
        "fc_weight": self.fc_weight,
        "mask_hidden": self.mask_hidden,
        "norm_weight": self.norm_weight,
        "lm_head_weight": self.lm_head_weight,
    }

    # Fusion-midlayer weights are passed as ``m_*``, stacked plain-layer
    # weights as ``p_*``; the order matches the kernel's parameter order.
    fusion_names = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "input_weight",
        "hidden_norm_weight",
        "post_attention_weight",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    plain_names = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "input_weight",
        "post_attention_weight",
        "gate_proj",
        "up_proj",
        "down_proj",
    )
    for weight_name in fusion_names:
        kernel_inputs[f"m_{weight_name}"] = self.m[weight_name]
    for weight_name in plain_names:
        kernel_inputs[f"p_{weight_name}"] = self.p[weight_name]

    self.kernel(inputs=kernel_inputs, outputs={"output0": self._draft_logits})

    # Argmax over the logits held by this process (no cross-rank reduction
    # happens in this method), taking batch row 0 only.
    host_logits = self._draft_logits.torch().float()
    local_ids = host_logits.argmax(dim=-1)[0]
    # d2t stores a per-id offset; an all-zero table leaves the ids unchanged.
    target_ids = local_ids + self.d2t[local_ids]
    return target_ids.tolist()
def _cross_depth_mask(K, dtype):
    """Build the additive (K, K) causal mask shared by all drafter layers.

    Entries strictly above the diagonal are -100000.0 (blocked); the diagonal
    and everything below it are 0, so row d only sees columns <= d. The mask
    is assembled in float64 and then cast to ``dtype``.
    """
    blocked = -100000.0
    mask = np.zeros((K, K))
    # k=1 selects the strictly-upper triangle, i.e. the "future" depths.
    mask[np.triu_indices(K, k=1)] = blocked
    return mask.astype(dtype)
def drafter_kernel(
    target_hidden3,  # (B, 1, 3*target_hidden): the 3 tapped target layers, concatenated
    last_emb,  # (B, 1, hidden): embedding of last accepted token
    ptd_emb,  # (1, hidden): placeholder-token embedding
    fc_weight,  # (3*target_hidden, hidden)
    mask_hidden,  # (1, 3*target_hidden) learnable shared hidden (pre-fc)
    norm_weight,
    lm_head_weight,
    # fusion midlayer weights
    m_q_proj,
    m_k_proj,
    m_v_proj,
    m_o_proj,
    m_input_weight,
    m_hidden_norm_weight,
    m_post_attention_weight,
    m_gate_proj,
    m_up_proj,
    m_down_proj,
    # plain layer weights, stacked on a leading axis of size num_plain
    p_q_proj,
    p_k_proj,
    p_v_proj,
    p_o_proj,
    p_input_weight,
    p_post_attention_weight,
    p_gate_proj,
    p_up_proj,
    p_down_proj,
    cfg,
):
    """Flat-signature entry point that repacks weights for ``drafter_forward``.

    The function projects both ``target_hidden3`` and ``mask_hidden`` through
    ``fc_weight``, groups the ``m_*`` arguments into the layer-0 weight dict,
    slices each ``p_*`` argument along its leading axis into one dict per
    plain layer, and returns whatever ``drafter_forward`` returns for that
    layer list. The number of plain layers is ``p_q_proj.shape[0]``.
    """
    # Both hidden sources go through the same fc projection.
    real_fused = np.matmul(target_hidden3, fc_weight)
    placeholder_fused = np.matmul(mask_hidden, fc_weight)

    fusion_layer = {
        "q_proj": m_q_proj,
        "k_proj": m_k_proj,
        "v_proj": m_v_proj,
        "o_proj": m_o_proj,
        "input_layernorm": m_input_weight,
        "hidden_norm": m_hidden_norm_weight,
        "post_attention_layernorm": m_post_attention_weight,
        "gate_proj": m_gate_proj,
        "up_proj": m_up_proj,
        "down_proj": m_down_proj,
    }

    # Layer-dict key -> stacked tensor; row i of each belongs to plain layer i.
    stacked_plain = {
        "q_proj": p_q_proj,
        "k_proj": p_k_proj,
        "v_proj": p_v_proj,
        "o_proj": p_o_proj,
        "input_layernorm": p_input_weight,
        "post_attention_layernorm": p_post_attention_weight,
        "gate_proj": p_gate_proj,
        "up_proj": p_up_proj,
        "down_proj": p_down_proj,
    }

    per_layer = [fusion_layer]
    for layer_idx in range(p_q_proj.shape[0]):
        per_layer.append(
            {key: stacked[layer_idx] for key, stacked in stacked_plain.items()}
        )

    return drafter_forward(
        real_fused,
        last_emb,
        placeholder_fused,
        ptd_emb,
        per_layer,
        norm_weight,
        lm_head_weight,
        cfg,
    )
def _silu(x):
    """Return ``x * sigmoid(x)`` using only elementwise numpy ops."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return x * sigmoid


def _mlp(x, gate_proj, up_proj, down_proj):
    """Gated MLP: ``(silu(x @ gate_proj) * (x @ up_proj)) @ down_proj``.

    All three weights are applied on the right of the activation (x @ W).
    """
    activated_gate = _silu(np.matmul(x, gate_proj))
    gated = activated_gate * np.matmul(x, up_proj)
    return np.matmul(gated, down_proj)


def _repeat_kv(x, n_rep):
    """Repeat every entry along axis 2 ``n_rep`` times.

    With ``n_rep == 1`` the input object itself is returned, untouched.
    """
    return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)
+ """ + B, S, _ = attn_input.shape + + xq = np.matmul(attn_input, q_proj).reshape(B, S, n_heads, head_dim) + xk = np.matmul(attn_input, k_proj).reshape(B, S, n_kv_heads, head_dim) + xv = np.matmul(attn_input, v_proj).reshape(B, S, n_kv_heads, head_dim) + + xq, xk = apply_rotary_emb_kernel(xq, xk, freqs_cos, freqs_sin) + + n_rep = n_heads // n_kv_heads + keys = _repeat_kv(xk, n_rep) + values = _repeat_kv(xv, n_rep) + + # BSHD -> BHSD + xq = xq.transpose(0, 2, 1, 3) + keys = keys.transpose(0, 2, 1, 3) + values = values.transpose(0, 2, 1, 3) + + scores = (xq @ keys.transpose(0, 1, 3, 2)) / np.float32(np.sqrt(head_dim)) + scores = scores.astype(nl.bfloat16) + scores = scores + np.expand_dims(attn_mask, axis=[0, 1]) + + weights = softmax_kernel(scores) + out = weights @ values # BHSD + + out = out.transpose(0, 2, 1, 3).reshape(B, S, n_heads * head_dim) + return np.matmul(out, o_proj) + + +def drafter_layer( + x, + weights, + cfg_norm_eps, + n_heads, + n_kv_heads, + head_dim, + freqs_cos, + freqs_sin, + attn_mask, + is_fusion, +): + """Run one drafter decoder layer. + + `weights` is a dict of this layer's numpy/device arrays. For the fusion + midlayer it additionally contains ``hidden_norm`` and `x` is ``2*hidden`` wide; + for plain layers `x` is ``hidden`` wide. + """ + if is_fusion: + hidden_size = x.shape[-1] // 2 + embeds = x[:, :, :hidden_size] + hidden = x[:, :, hidden_size:] + + # norm_before_residual is False for this checkpoint: residual is the raw + # hidden half, hidden_norm is applied only to the attention input. 
def rmsnorm_kernel(
    x,
    weight,
    eps: float,
    compute_dtype=np.float32,  # reduce numerical error
):
    """RMS-normalize ``x`` over its last axis and scale by ``weight``.

    Both operands are cast to ``compute_dtype`` for the arithmetic; the result
    is cast back to the dtype ``x`` had on entry.
    """
    input_dtype = x.dtype
    x_wide = x.astype(compute_dtype)
    w_wide = weight.astype(compute_dtype)

    mean_sq = np.mean(np.square(x_wide), axis=-1, keepdims=True)
    # Adding eps may promote the dtype; pin it back before the sqrt.
    denom = np.sqrt((mean_sq + eps).astype(x_wide.dtype))

    return ((x_wide / denom) * w_wide).astype(input_dtype)
def compute_cos_sin_cache(
    inv_freq, max_seq_len, attention_scaling=1.0, dtype=np.float32
):
    """Tabulate scaled cos/sin of ``position * inv_freq``.

    Returns a ``(cos, sin)`` pair, each of shape
    ``(max_seq_len, len(inv_freq))``. The values are computed in float32,
    multiplied by ``attention_scaling`` and then cast to ``dtype``.
    """
    positions = np.arange(max_seq_len, dtype=np.float32)
    angles = np.outer(positions, np.asarray(inv_freq, dtype=np.float32))
    cos, sin = (
        (trig(angles, dtype=np.float32) * attention_scaling).astype(dtype)
        for trig in (np.cos, np.sin)
    )
    return cos, sin


def apply_rotary_emb_kernel(xq, xk, freqs_cos, freqs_sin):
    """Rotate-halves RoPE on query/key tensors of shape (B, S, H, D).

    ``freqs_cos`` / ``freqs_sin`` get new axes at positions 0 and 2 so that
    they broadcast over batch and heads. The split point for both tensors is
    half of ``xq``'s last dimension.
    """
    cos = np.expand_dims(freqs_cos, axis=(0, 2))
    sin = np.expand_dims(freqs_sin, axis=(0, 2))
    half = xq.shape[-1] // 2

    def _rotate(t):
        lo = t[:, :, :, :half]
        hi = t[:, :, :, half:]
        return np.concatenate([lo * cos - hi * sin, hi * cos + lo * sin], axis=-1)

    return _rotate(xq), _rotate(xk)
def verify_argmax(h, norm_weight, lm_head_weight, configs):
    """Return one greedy token id per position of ``h``.

    Args:
        h: (B, S, H) hidden states.
        norm_weight: weight for the RMS norm applied before the projection.
        lm_head_weight: projection applied on the right of ``h``; its second
            dimension is taken as the per-rank vocabulary width.
        configs: only ``configs.norm_eps`` is read.

    Returns:
        (B, S) int32 array computed as
        ``winning_rank * vocab_width + local_index``.
    """
    batch, seq_len, _ = h.shape
    n_ranks = dist.get_world_size()
    normed = rmsnorm_kernel(h, norm_weight, configs.norm_eps)
    vocab_width = lm_head_weight.shape[1]

    # Best local candidate for every position in one shot.
    local_logits = np.matmul(normed, lm_head_weight)
    top_val, top_idx = tensor_apis.topk(local_logits, k=1, axis=-1)
    top_val = top_val[:, :, 0]
    top_idx = top_idx[:, :, 0]

    def _gather_over_ranks(per_rank):
        # all_gather concatenates along axis 0; reshape splits the rank axis
        # back out in front of the batch axis.
        gathered = cc.all_gather(
            per_rank, all_gather_dim=0, replica_groups=[list(range(n_ranks))]
        )
        return gathered.reshape(n_ranks, batch, seq_len)

    vals_by_rank = _gather_over_ranks(top_val)
    idxs_by_rank = _gather_over_ranks(top_idx)

    # Rank holding the largest logit for each (batch, position).
    winner = np.argmax(vals_by_rank, axis=0)
    winner_local = np.take_along_axis(
        idxs_by_rank, np.expand_dims(winner, 0), axis=0
    )[0]
    return (winner * vocab_width + winner_local).astype(np.int32)
+ +Run (from the gpt_oss/ directory, with eagle/ on PYTHONPATH): + torchrun --nproc-per-node $TP eagle/speculate.py \ + --target-checkpoint ./tmp_gpt-oss-20b \ + --draft-checkpoint ./tmp_p-eagle \ + --model /home/ubuntu/models/gpt-oss-20b \ + --draft-model /home/ubuntu/models/GPT-OSS-20B-P-EAGLE \ + -n 200 -k 7 "The capital of France is" +""" + +import argparse +import os +import sys +import time + +import numpy as np +import torch +import torch.distributed as dist + +# Both the base gpt_oss package and the eagle subpackage have flat `config.py` / +# `kernels/` modules. We put the base dir (gpt_oss/) FIRST on sys.path so the base +# flat modules win, and import everything eagle-specific as the `eagle.*` package +# (which can't collide). Works when run as `eagle/speculate.py` from gpt_oss/. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_BASE = os.path.dirname(_HERE) +# torchrun puts the script's own dir (eagle/) on sys.path[0], which would shadow +# the base flat modules (config.py, kernels/) with eagle's. Drop it and put the +# base dir first; eagle code is reached via the `eagle.*` package from _BASE. 
def _resolve_eos_ids(model_name, tokenizer):
    """Return the set of end-of-sequence token ids to stop generation on.

    The model's generation config is tried first. If it cannot be loaded, or
    it loads but carries no ``eos_token_id``, the tokenizer's ``eos_token_id``
    attribute is used instead. An empty set is returned when neither source
    provides an id.

    Args:
        model_name: name or path handed to ``GenerationConfig.from_pretrained``.
        tokenizer: any object; its ``eos_token_id`` attribute is read if present.

    Returns:
        set of ids; a list/tuple id is expanded, a scalar becomes a singleton.
    """
    eos = None
    try:
        from transformers import GenerationConfig

        eos = GenerationConfig.from_pretrained(model_name).eos_token_id
    except Exception:
        # Best effort: a missing package or missing/unreadable config must
        # not abort generation, so fall through to the tokenizer.
        eos = None

    # A config that loads without an EOS id must not suppress the fallback.
    if eos is None:
        eos = getattr(tokenizer, "eos_token_id", None)
    if eos is None:
        return set()
    return set(eos) if isinstance(eos, (list, tuple)) else {eos}
name=f"verify_layer_{'sw' if sliding else 'full'}", + x=x_verify, + start_pos=start_pos, + qkv_weight=lt["qkv_weight"], + qkv_bias=lt["qkv_bias"], + o_weight=lt["o_weight"], + o_bias=lt["o_bias"], + sinks=lt["sinks"], + input_weight=lt["input_weight"], + post_attention_weight=lt["post_attention_weight"], + router_weight=lt["router_weight"], + router_bias=lt["router_bias"], + gate_up_weight=lt["gate_up_weight"], + gate_up_bias=lt["gate_up_bias"], + down_weight=lt["down_weight"], + down_bias=lt["down_bias"], + cache_k=lt["cache_k"], + cache_v=lt["cache_v"], + configs=cfg, + sliding_window=sw, + build_dir=base.BUILD_DIR, + additional_compiler_args=cfg.additional_compiler_args_nkipy, + ) + + self.kernel_verify_argmax = DeviceKernel.compile_and_load( + verify_argmax, + name="verify_argmax", + h=x_verify, + norm_weight=self.norm_weight, + lm_head_weight=self.lm_head_weight, + configs=cfg, + build_dir=base.BUILD_DIR, + additional_compiler_args=cfg.additional_compiler_args_nkipy, + ) + + def verify(self, tokens, start_pos): + """Run K+1 candidate tokens through the target at absolute `start_pos`. + + Args: + tokens: list of K+1 token ids (last accepted token + K drafts). + start_pos: absolute position of tokens[0] in the sequence. + Returns: + (target_tokens, aux) where target_tokens is a length-(K+1) list of the + target's greedy next token at each position, and aux is the list of 3 + captured tap-layer hidden states (each (B, K+1, H), host tensors). 
def main():
    """Command-line entry point: greedy speculative decoding of one prompt.

    Flow:
      1. Parse arguments, initialize the process group, and pin this rank to
         its own Neuron core.
      2. Load the target shard for this rank and the drafter weights.
      3. Prefill the target on the prompt and sample the first token.
      4. Run one single-token decode on that token to capture the tap-layer
         hidden states that seed the first draft.
      5. Loop: draft K tokens, verify them in one target pass, keep the
         longest matching prefix plus the target's next token, and advance.
      6. On rank 0, stream the decoded text and print timing and acceptance
         statistics.

    Invariant kept throughout the decode loop: ``cur_pos`` is the absolute
    position of ``generated[-1]``, which is the position ``verify`` is given
    for the first token of its candidate window.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--max-new-tokens", type=int, default=128)
    parser.add_argument("-k", "--num-draft-tokens", type=int, default=7)
    parser.add_argument("prompt", nargs="?", default="The capital of France is")
    parser.add_argument("--target-checkpoint", default="./tmp_gpt-oss-20b")
    parser.add_argument("--draft-checkpoint", default="./tmp_p-eagle")
    parser.add_argument("--model", default="/home/ubuntu/models/gpt-oss-20b")
    parser.add_argument(
        "--draft-model", default="/home/ubuntu/models/GPT-OSS-20B-P-EAGLE"
    )
    args = parser.parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["NEURON_RT_ROOT_COMM_ID"] = "localhost:61239"
    dist.init_process_group()
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    # The rank is only known once the process group exists.
    os.environ["NEURON_RT_VISIBLE_CORES"] = str(dist.get_rank())

    K = args.num_draft_tokens
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    input_ids = tokenizer(args.prompt, return_tensors="np")["input_ids"]
    prompt_len = input_ids.shape[1]
    eos_ids = _resolve_eos_ids(args.model, tokenizer)

    # Target config + aux taps.
    config = get_config(args.model, prompt_len, args.max_new_tokens)
    config.aux_layers = Config.default_aux_layers(config.num_layers)

    print_log("Loading target weights")
    tgt_shard = os.path.join(
        args.target_checkpoint, f"shard_{dist.get_rank()}.safetensors"
    )
    target = SpeculativeGptOss(load_file(tgt_shard, device="cpu"), config, K)

    # Every rank loads the same drafter file.
    print_log("Loading drafter weights")
    ecfg = get_eagle_config(
        args.draft_model,
        target_hidden_size=config.hidden_size,
        num_draft_tokens=K,
        max_seq_len=config.max_seq_len,
    )
    draft_weights = load_file(
        os.path.join(args.draft_checkpoint, "drafter.safetensors"), device="cpu"
    )
    drafter = DrafterModel(draft_weights, ecfg, base.BUILD_DIR)

    # ── Prefill the target on the prompt ──
    dist.barrier()
    t0 = time.time()
    hidden, aux = target.run_prefill(input_ids, capture_aux=True)

    # First token: greedy sample from the prefill hidden (reuse the base kernel).
    first_id_dev = DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), "first_id")
    target.kernel_cte_greedy_sampling(
        inputs={
            "h": hidden,
            "norm_weight": target.norm_weight,
            "lm_head_weight": target.lm_head_weight,
        },
        outputs={"output0": first_id_dev},
    )
    next_id = int(first_id_dev.torch().reshape(-1)[0])

    generated = [next_id]
    cur_pos = prompt_len  # absolute position of `next_id`

    # Seed the drafter: run a single-token decode on `next_id` at `cur_pos`
    # and capture the tap-layer hidden states along the way.
    seed_h = DeviceTensor.from_torch(
        target.tok_embedding[torch.tensor([[next_id]])], "seed_h"
    )
    seed_pos = DeviceTensor.from_numpy(np.array([cur_pos], dtype=np.int32), "seed_pos")
    seed_aux = []
    for i in range(config.num_layers):
        target._run_layer("tkg", i, seed_h, seed_pos)
        if config.aux_layers is not None and i in config.aux_layers:
            seed_aux.append(seed_h.torch().clone())
    last_aux3 = _stack_aux([a[:, 0:1, :] for a in seed_aux])
    # Do NOT advance cur_pos here. The first verify pass feeds `next_id` again
    # as the head of its window; placing that window at prompt_len + 1 would
    # leave the target attending to `next_id` at two consecutive positions.
    # Starting the window at `cur_pos` re-processes the same token at the same
    # position instead.

    ttft = time.time() - t0
    n_accepted_total = 0
    n_steps = 0

    if dist.get_rank() == 0:
        print(f"\n{args.prompt}", end="")
        print(tokenizer.decode([next_id]), end="")
        sys.stdout.flush()

    t_decode = time.time()
    while len(generated) < args.max_new_tokens:
        # 1) Draft K tokens from the last accepted token + its tapped hiddens.
        drafts = drafter.draft(last_aux3, generated[-1])

        # 2) Verify: feed [last_token, drafts...] with last_token at cur_pos.
        cand = [generated[-1]] + drafts  # length K+1
        target_ids, aux = target.verify(cand, cur_pos)

        # 3) Accept the longest matching prefix (greedy). target_ids[i] is the
        #    target's choice after window index i, so it either confirms
        #    drafts[i] or replaces it and ends the step.
        accepted = []
        for i in range(K):
            accepted.append(target_ids[i])
            if drafts[i] != target_ids[i]:
                break
        else:
            # all K matched; the (K+1)-th target token is a free bonus token
            accepted.append(target_ids[K])

        n_accepted = len(accepted)
        n_accepted_total += n_accepted
        n_steps += 1

        # Emit accepted tokens (truncate at max + stop on EOS).
        stop = False
        for tok in accepted:
            if len(generated) >= args.max_new_tokens:
                stop = True
                break
            generated.append(tok)
            if dist.get_rank() == 0:
                print(tokenizer.decode([tok]), end="")
                sys.stdout.flush()
            if tok in eos_ids:
                stop = True
                break
        if stop:
            break

        # 4) Advance. Window indices 0 .. n_accepted-1 (the previous last token
        #    and the matched drafts) are the ones that stay in the sequence;
        #    they sit at cur_pos .. cur_pos+n_accepted-1. The newest accepted
        #    token therefore belongs at cur_pos+n_accepted, and the hiddens at
        #    window index n_accepted-1 seed the next draft.
        last_kept_index = n_accepted - 1  # index into the verify window
        last_aux3 = _stack_aux(
            [a[:, last_kept_index : last_kept_index + 1, :] for a in aux]
        )
        cur_pos += n_accepted

    decode_time = time.time() - t_decode
    if dist.get_rank() == 0:
        n_new = len(generated)
        accept_len = n_accepted_total / max(n_steps, 1)
        print(f"\n\nTime to first token: {ttft:.2f}s")
        print(f"Generated {n_new} tokens in {n_steps} verify steps")
        print(f"Mean acceptance length: {accept_len:.2f} (K={K})")
        print(f"Decode tokens/sec: {n_new / max(decode_time, 1e-6):.2f}")
def _attn_block(get, prefix, dtype):
    """Fetch the four attention projections under ``prefix``.

    Each ``{prefix}.self_attn.<name>.weight`` tensor is cast to ``dtype``,
    transposed, and made contiguous. Keys of the returned dict are
    ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``, in that order.
    """
    return {
        name: get(f"{prefix}.self_attn.{name}.weight").to(dtype).T.contiguous()
        for name in ("q_proj", "k_proj", "v_proj", "o_proj")
    }


def _mlp_block(get, prefix, dtype):
    """Fetch the three MLP projections under ``prefix``.

    Each ``{prefix}.mlp.<name>.weight`` tensor is cast to ``dtype``,
    transposed, and made contiguous. Keys of the returned dict are
    ``gate_proj``, ``up_proj``, ``down_proj``, in that order.
    """
    return {
        name: get(f"{prefix}.mlp.{name}.weight").to(dtype).T.contiguous()
        for name in ("gate_proj", "up_proj", "down_proj")
    }
def prepare(model_name, output_dir, dtype=torch.bfloat16):
    """Convert a drafter checkpoint directory into ``drafter.safetensors``.

    Args:
        model_name: local checkpoint directory. It must contain
            ``model.safetensors`` and is also passed to
            ``AutoConfig.from_pretrained``.
        output_dir: destination directory, created if it does not exist.
        dtype: torch dtype forwarded to ``build``.

    The converted tensors are written to ``<output_dir>/drafter.safetensors``.
    """
    os.makedirs(output_dir, exist_ok=True)
    config = AutoConfig.from_pretrained(model_name)
    checkpoint_path = os.path.join(model_name, "model.safetensors")

    print(f"[1/1] Converting drafter `{model_name}` (replicated)...")
    # Keep the checkpoint open only while tensors are being read, so the file
    # handle is released even if the conversion raises.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as handle:
        shard = build(handle.get_tensor, config, dtype)

    path = os.path.join(output_dir, "drafter.safetensors")
    save_file(shard, path)
    print(f" - wrote {path}")
drafter into a replicated bf16 safetensors." + ) + parser.add_argument( + "--model-name", required=True, help="path to P-EAGLE checkpoint" + ) + parser.add_argument("--output-dir", default="tmp_p-eagle") + parser.add_argument("--dtype", choices=["f32", "bf16"], default="bf16") + args = parser.parse_args() + dtype = {"f32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + prepare(args.model_name, args.output_dir, dtype=dtype) diff --git a/examples/models/gpt_oss/gpt_oss.py b/examples/models/gpt_oss/gpt_oss.py new file mode 100644 index 0000000..ded3184 --- /dev/null +++ b/examples/models/gpt_oss/gpt_oss.py @@ -0,0 +1,444 @@ +import argparse +import os +import sys +import time + +import numpy as np +import torch +import torch.distributed as dist +from config import Config, get_config +from kernels.sampling import greedy_sampling, greedy_sampling_with_embedding +from kernels.transformer_layer import transformer_layer +from nkipy.runtime import DeviceKernel, DeviceTensor +from safetensors.torch import load_file +from transformers import AutoTokenizer +from utils import print_log + +# Absolute path: the compiler chdir's into the per-kernel build dir, so a +# relative build dir would double up and the HLO module wouldn't be found. 
+BUILD_DIR = os.path.abspath("./build") +USE_NKI_RMSNORM = True + +# weight names carried per layer (everything except the kv cache) +_LAYER_WEIGHT_KEYS = [ + "qkv_weight", + "qkv_bias", + "o_weight", + "o_bias", + "sinks", + "input_weight", + "post_attention_weight", + "router_weight", + "router_bias", + "gate_up_weight", + "gate_up_bias", + "down_weight", + "down_bias", +] + + +class GptOssModel: + def __init__(self, model_weights, config: Config): + """Initialize the model with weights and configuration.""" + self.config = config + self.tok_embedding = model_weights.get("tok_embedding") + + # kernels keyed by (phase, sliding) -> compiled DeviceKernel + self.kernel_layer = {} + self.kernel_cte_greedy_sampling = None + self.kernel_cte_greedy_sampling_embed = None + self.kernel_tkg_greedy_sampling = None + self.kernel_tkg_greedy_sampling_embed = None + + self.norm_weight = None + self.lm_head_weight = None + self.tok_embedding_device = None + + self._prepare_tensors(model_weights) + self._prepare_kernels() + + def _prepare_tensors(self, weights): + t = time.time() + print_log("Preparing Tensors") + + n_local_kv_heads = max(1, self.config.num_kv_heads // dist.get_world_size()) + + cache_shape = ( + self.config.max_batch_size, + self.config.max_seq_len, + n_local_kv_heads, + self.config.head_dim, + ) + cache_k = np.zeros(cache_shape, dtype=self.config.dtype) + cache_v = np.zeros(cache_shape, dtype=self.config.dtype) + + # Per-layer device tensors. 
+ self.layer_tensors = [] + for layer_id in range(self.config.num_layers): + lt = {} + for key in _LAYER_WEIGHT_KEYS: + w = weights.get(f"layers.{layer_id}.{key}") + lt[key] = DeviceTensor.from_torch(w, f"{key}_L{layer_id}") + lt["cache_k"] = DeviceTensor.from_numpy(cache_k, f"cache_k_L{layer_id}") + lt["cache_v"] = DeviceTensor.from_numpy(cache_v, f"cache_v_L{layer_id}") + self.layer_tensors.append(lt) + + self.norm_weight = DeviceTensor.from_torch( + weights.get("norm_weight"), "norm_weight" + ) + self.lm_head_weight = DeviceTensor.from_torch( + weights.get("lm_head_weight"), "lm_head_weight" + ) + self.tok_embedding_device = DeviceTensor.from_torch( + self.tok_embedding, "tok_embedding" + ) + + print_log(f"--> Finished Preparing Tensors in {time.time() - t:.2f}s") + + def _compile_layer(self, name, x, start_pos, sliding_window): + lt = self.layer_tensors[0] + return DeviceKernel.compile_and_load( + transformer_layer, + name=name, + x=x, + start_pos=start_pos, + qkv_weight=lt["qkv_weight"], + qkv_bias=lt["qkv_bias"], + o_weight=lt["o_weight"], + o_bias=lt["o_bias"], + sinks=lt["sinks"], + input_weight=lt["input_weight"], + post_attention_weight=lt["post_attention_weight"], + router_weight=lt["router_weight"], + router_bias=lt["router_bias"], + gate_up_weight=lt["gate_up_weight"], + gate_up_bias=lt["gate_up_bias"], + down_weight=lt["down_weight"], + down_bias=lt["down_bias"], + cache_k=lt["cache_k"], + cache_v=lt["cache_v"], + configs=self.config, + sliding_window=sliding_window, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + + def _prepare_kernels(self): + t = time.time() + print_log("Preparing kernels") + + x_context = DeviceTensor.from_numpy( + np.empty( + ( + self.config.max_batch_size, + self.config.context_len, + self.config.hidden_size, + ), + dtype=self.config.dtype, + ), + "x_context", + ) + x_token = DeviceTensor.from_numpy( + np.empty( + (self.config.max_batch_size, 1, self.config.hidden_size), + 
dtype=self.config.dtype, + ), + "x_token", + ) + start_pos = DeviceTensor.from_numpy(np.empty((1), dtype=np.int32), "start_pos") + + # gpt-oss alternates sliding-window and full attention per layer, so we + # compile one transformer-layer kernel per (phase, attention-type) pair. + for sliding in (False, True): + sw = self.config.sliding_window if sliding else None + self.kernel_layer[("cte", sliding)] = self._compile_layer( + f"cte_layer_{'sw' if sliding else 'full'}", x_context, None, sw + ) + self.kernel_layer[("tkg", sliding)] = self._compile_layer( + f"tkg_layer_{'sw' if sliding else 'full'}", x_token, start_pos, sw + ) + + common = dict( + norm_weight=self.norm_weight, + lm_head_weight=self.lm_head_weight, + configs=self.config, + use_nki_rmsnorm=USE_NKI_RMSNORM, + build_dir=BUILD_DIR, + additional_compiler_args=self.config.additional_compiler_args_nkipy, + ) + self.kernel_cte_greedy_sampling = DeviceKernel.compile_and_load( + greedy_sampling, name="cte_greedy_sampling", h=x_context, **common + ) + self.kernel_tkg_greedy_sampling = DeviceKernel.compile_and_load( + greedy_sampling, name="tkg_greedy_sampling", h=x_token, **common + ) + self.kernel_cte_greedy_sampling_embed = DeviceKernel.compile_and_load( + greedy_sampling_with_embedding, + name="cte_greedy_sampling_embed", + h=x_context, + tok_embedding=self.tok_embedding_device, + **common, + ) + self.kernel_tkg_greedy_sampling_embed = DeviceKernel.compile_and_load( + greedy_sampling_with_embedding, + name="tkg_greedy_sampling_embed", + h=x_token, + tok_embedding=self.tok_embedding_device, + **common, + ) + + print_log( + f"--> Finished Kernel Compilation and Loading in {time.time() - t:.2f}s" + ) + + def _run_layer(self, phase, i, hidden_states, start_pos): + lt = self.layer_tensors[i] + kernel = self.kernel_layer[(phase, self.config.is_sliding(i))] + inputs = {key: lt[key] for key in _LAYER_WEIGHT_KEYS} + inputs["x"] = hidden_states + inputs["cache_k.must_alias_input"] = lt["cache_k"] + 
inputs["cache_v.must_alias_input"] = lt["cache_v"] + if phase == "tkg": + inputs["start_pos"] = start_pos + kernel( + inputs=inputs, + outputs={ + "output0": hidden_states, + "cache_k": lt["cache_k"], + "cache_v": lt["cache_v"], + }, + ) + + def run_prefill(self, input_ids, capture_aux=False): + """Run the context-encoding (prefill) layer stack. + + Args: + input_ids: prompt token ids, shape (B, L). + capture_aux: when True, also return the residual-stream hidden states + produced by the EAGLE-3 tap layers (``config.aux_layers``), in + low->mid->high order. Each is a host torch tensor of shape + (B, L, hidden_size). Used to seed the speculative drafter. + + Returns: + (hidden_states, aux) where hidden_states is the final-layer device + tensor and aux is a list of captured host tensors (empty unless + capture_aux and aux_layers are set). + """ + hidden_states = DeviceTensor.from_torch( + self.tok_embedding[input_ids], "hidden_states" + ) + + aux_layers = self.config.aux_layers if capture_aux else None + aux = [] + for i in range(self.config.num_layers): + self._run_layer("cte", i, hidden_states, None) + if aux_layers is not None and i in aux_layers: + aux.append(hidden_states.torch().clone()) + + return hidden_states, aux + + def generate(self, input_ids, double_buffering=True): + """Run inference and generate tokens.""" + hidden_states, _ = self.run_prefill(input_ids, capture_aux=False) + + if double_buffering: + yield from self._generate_double_buffered(hidden_states) + else: + yield from self._generate_baseline(hidden_states) + + def _run_tkg_layers(self, hidden_states, start_pos): + for i in range(self.config.num_layers): + self._run_layer("tkg", i, hidden_states, start_pos) + + # ── Double-buffered decode (fused sampling + on-device embedding) ────── + + def _generate_double_buffered(self, hidden_states): + context_len = self.config.context_len + B = self.config.max_batch_size + + next_id_bufs = [ + DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), 
"next_id_0"), + DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), "next_id_1"), + ] + decode_hidden = DeviceTensor.from_numpy( + np.empty((B, 1, self.config.hidden_size), dtype=self.config.dtype), + "decode_hidden", + ) + + cur_buf = 0 + self.kernel_cte_greedy_sampling_embed( + inputs={ + "h": hidden_states, + "norm_weight": self.norm_weight, + "lm_head_weight": self.lm_head_weight, + "tok_embedding": self.tok_embedding_device, + }, + outputs={"output0": next_id_bufs[cur_buf], "output1": decode_hidden}, + ) + + for pos in range(context_len, context_len + self.config.max_new_tokens): + prev_buf = cur_buf + cur_buf = 1 - cur_buf + + t_start_pos = DeviceTensor.from_numpy(np.array([pos], dtype=np.int32)) + self._run_tkg_layers(decode_hidden, t_start_pos) + + self.kernel_tkg_greedy_sampling_embed( + inputs={ + "h": decode_hidden, + "norm_weight": self.norm_weight, + "lm_head_weight": self.lm_head_weight, + "tok_embedding": self.tok_embedding_device, + }, + outputs={"output0": next_id_bufs[cur_buf], "output1": decode_hidden}, + ) + + next_id_torch = ( + next_id_bufs[prev_buf].torch().reshape(B, 1).to(dtype=torch.int) + ) + yield next_id_torch + + next_id_torch = next_id_bufs[cur_buf].torch().reshape(B, 1).to(dtype=torch.int) + yield next_id_torch + + # ── Baseline decode (host embedding lookup, no double buffering) ────── + + def _generate_baseline(self, hidden_states): + context_len = self.config.context_len + B = self.config.max_batch_size + + next_id = DeviceTensor.from_numpy(np.array([[0]], dtype=np.uint32), "next_id") + + self.kernel_cte_greedy_sampling( + inputs={ + "h": hidden_states, + "norm_weight": self.norm_weight, + "lm_head_weight": self.lm_head_weight, + }, + outputs={"output0": next_id}, + ) + next_id_torch = next_id.torch().reshape(B, 1).to(dtype=torch.int) + yield next_id_torch + + t_start_pos = DeviceTensor.from_numpy( + np.array([context_len], dtype=np.int32), "start_pos" + ) + hidden_states = DeviceTensor.from_torch( + 
self.tok_embedding[next_id_torch], "h0/res1" + ) + + for pos in range(context_len, context_len + self.config.max_new_tokens): + t_start_pos.write_from_numpy(np.array([pos], dtype=np.int32)) + hidden_states.write_from_torch(self.tok_embedding[next_id_torch]) + + self._run_tkg_layers(hidden_states, t_start_pos) + + self.kernel_tkg_greedy_sampling( + inputs={ + "h": hidden_states, + "norm_weight": self.norm_weight, + "lm_head_weight": self.lm_head_weight, + }, + outputs={"output0": next_id}, + ) + + next_id_torch = next_id.torch().reshape(B, 1).to(dtype=torch.int) + yield next_id_torch + + +def _resolve_eos_ids(model_name, tokenizer): + """Collect stop token ids from the generation config (falls back to tokenizer).""" + try: + from transformers import GenerationConfig + + gen = GenerationConfig.from_pretrained(model_name) + eos = gen.eos_token_id + except Exception: + eos = getattr(tokenizer, "eos_token_id", None) + if eos is None: + return set() + return set(eos) if isinstance(eos, (list, tuple)) else {eos} + + +def load_model(args): + """Initialize distributed env, load weights, and build a GptOssModel.""" + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["NEURON_RT_ROOT_COMM_ID"] = "localhost:61239" + + dist.init_process_group() + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + os.environ["NEURON_RT_VISIBLE_CORES"] = str(dist.get_rank()) + + tokenizer = AutoTokenizer.from_pretrained(args.model) + model_inputs = tokenizer(args.prompt, return_tensors="np") + input_ids = model_inputs["input_ids"] + config = get_config(args.model, input_ids.shape[1], args.max_new_tokens) + args.eos_ids = _resolve_eos_ids(args.model, tokenizer) + + print_log("Loading Model Weights") + shard_path = os.path.join(args.checkpoint, f"shard_{dist.get_rank()}.safetensors") + weights = load_file(shard_path, device="cpu") + + double_buffering = getattr(args, "double_buffering", True) + model = GptOssModel(weights, config) + + start = 
time.time() + print_log("Warming model") + t = 0 + for _ in model.generate(input_ids, double_buffering=double_buffering): + if t == 1: + break + t += 1 + print_log(f"--> Finished warming the model in {time.time() - start:.2f}s") + + return model, input_ids, tokenizer + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--max-new-tokens", type=int, default=16) + parser.add_argument("prompt", nargs="?", default="The capital of France is") + parser.add_argument("--checkpoint", default="./tmp_gpt-oss-20b") + parser.add_argument("--model", default="openai/gpt-oss-20b") + parser.add_argument( + "--no-double-buffering", + action="store_true", + help="Disable fused embedding + double-buffered decoding (for perf comparison)", + ) + args = parser.parse_args() + args.double_buffering = not args.no_double_buffering + + model, input_ids, tokenizer = load_model(args) + + dist.barrier() + start = time.time() + t = 0 + first_token_time = start + if dist.get_rank() == 0: + print(f"\n{args.prompt}", end="") + eos_ids = getattr(args, "eos_ids", set()) + for id in model.generate(input_ids, double_buffering=args.double_buffering): + if t == 0: + first_token_time = time.time() + t += 1 + output_id = id[0].tolist() + if output_id[-1] in eos_ids: + print_log("Found special/EOS token, stop early") + break + if dist.get_rank() == 0: + print(tokenizer.decode(output_id), end="") + sys.stdout.flush() + + end_time = time.time() + ttft = first_token_time - start + decoding_time = max(end_time - first_token_time, 1e-6) + tokens_per_second = t / decoding_time + if dist.get_rank() == 0: + print(f"\nTime to first token: {ttft:.2f}s") + print(f"Decoding tokens per second: {tokens_per_second:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/gpt_oss/kernels/__init__.py b/examples/models/gpt_oss/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/models/gpt_oss/kernels/attention.py 
b/examples/models/gpt_oss/kernels/attention.py new file mode 100644 index 0000000..34e1be1 --- /dev/null +++ b/examples/models/gpt_oss/kernels/attention.py @@ -0,0 +1,170 @@ +from typing import Optional + +import neuronxcc.nki.language as nl +import nkipy.core.typing as nt +import nkipy.distributed.collectives as cc +import numpy as np +import torch.distributed as dist +from nkipy.core import tensor_apis + +from .rope import apply_rotary_emb_kernel, compute_cos_sin_cache +from .softmax import softmax_kernel + + +def repeat_kv_kernel(x, n_rep: int): + """Repeat key-value heads for grouped-query attention.""" + if n_rep == 1: + return x + z = np.repeat(x, n_rep, axis=2) + return z + + +def attention_kernel( + x, + qkv_weight, + qkv_bias, + sinks, + rope_inv_freq, + rope_attention_scaling, + n_heads, + head_dim, + n_kv_heads, + cache_k, + cache_v, + start_pos: Optional[nt.tensor], + o_weight, + o_bias, + sliding_window: Optional[int], +): + """Unified attention kernel for gpt-oss. + + Differences from a plain GQA attention: + * QKV and output projections carry biases. + * No QK RMSNorm. + * Per-head attention sinks: a learned logit per head is concatenated to the + attention scores before softmax, then dropped afterwards. + * Sliding-window masking on the layers configured for it. + * YaRN RoPE (cos/sin baked from precomputed inverse frequencies). + + When start_pos is None: prefill mode (process full context). + When start_pos is provided: decode mode (process single token). + """ + is_prefill = start_pos is None + batch_size, seq_len, _ = x.shape + + n_local_heads = n_heads // dist.get_world_size() + assert n_local_heads > 0, f"n_local_heads {n_local_heads} is not greater than 0" + n_local_kv_heads = max(1, n_kv_heads // dist.get_world_size()) + n_rep = n_local_heads // n_local_kv_heads + + # QKV projection (+ bias). GQA: KV's head count differs from Q's. 
+ split_axis = x.ndim - 1 + split0 = n_local_heads * head_dim + split1 = split0 + n_local_kv_heads * head_dim + splits = [split0, split1] + qkv = np.matmul(x, qkv_weight) + qkv_bias + xq, xk, xv = np.split(qkv, splits, axis=split_axis) + + xq = xq.reshape(batch_size, seq_len, n_local_heads, head_dim) + xk = xk.reshape(batch_size, seq_len, n_local_kv_heads, head_dim) + xv = xv.reshape(batch_size, seq_len, n_local_kv_heads, head_dim) + + # RoPE (YaRN) + max_seq_len = cache_k.shape[1] + freqs_cos, freqs_sin = compute_cos_sin_cache( + rope_inv_freq, max_seq_len, rope_attention_scaling, dtype=nl.bfloat16 + ) + if is_prefill: + freqs_cos = freqs_cos[0:seq_len] + freqs_sin = freqs_sin[0:seq_len] + else: + # Decode (seq_len==1) and speculative verify (seq_len==K+1) both write at a + # runtime offset. The absolute positions of the query tokens are + # start_pos + [0, 1, ..., seq_len-1]; promote comptime numpy arrays to + # runtime tensors so they can be gathered with that runtime index. + query_pos = start_pos + np.arange(seq_len, dtype=np.int32) + freqs_cos = tensor_apis.constant(freqs_cos) + freqs_sin = tensor_apis.constant(freqs_sin) + freqs_cos = freqs_cos[query_pos] + freqs_sin = freqs_sin[query_pos] + xq, xk = apply_rotary_emb_kernel(xq, xk, freqs_cos, freqs_sin) + + # KV cache update + if is_prefill: + cache_k[:, :seq_len] = xk + cache_v[:, :seq_len] = xv + else: + # Scatter the seq_len new K/V rows into the cache at their absolute + # positions (one row for decode, K+1 contiguous rows for verify). 
+ cache_k[:, query_pos] = xk + cache_v[:, query_pos] = xv + + # GQA: repeat KV heads + keys = repeat_kv_kernel(cache_k, n_rep) + values = repeat_kv_kernel(cache_v, n_rep) + + # Transpose for attention: BSHD -> BHSD + xq = xq.transpose(0, 2, 1, 3) + keys = keys.transpose(0, 2, 1, 3) + values = values.transpose(0, 2, 1, 3) + + # Attention scores: BHSD @ BHDS -> BHSS + k_seq_len = keys.shape[2] + scores = (xq @ keys.transpose(0, 1, 3, 2)) / np.float32(np.sqrt(head_dim)) + scores = scores.astype(nl.bfloat16) + + # Causal (+ optional sliding-window) mask, computed at compile time. + NEG = -100000.0 + full_mask = np.triu(np.ones((k_seq_len, k_seq_len)) * NEG, k=1) + if sliding_window is not None: + # Disallow attending further back than `sliding_window` tokens: mask + # positions where (query_pos - key_pos) >= sliding_window. + window_mask = np.tril(np.ones((k_seq_len, k_seq_len)) * NEG, k=-sliding_window) + full_mask = full_mask + window_mask + causal_mask = full_mask.astype(scores.dtype) + causal_mask = tensor_apis.constant(causal_mask) + + if is_prefill: + scores = scores + np.expand_dims(causal_mask[:seq_len, :k_seq_len], axis=[0, 1]) + else: + # Gather one mask row per query token. For verify (seq_len>1) this yields a + # block-causal mask: query at start_pos+i attends to keys <= start_pos+i + # (plus the sliding-window limit, already baked into causal_mask). + scores = scores + np.expand_dims( + causal_mask[query_pos, :k_seq_len], axis=[0, 1] + ) + + # Attention sinks: concatenate a per-head learned logit as an extra "key" + # column, softmax over (keys + sink), then drop the sink probability. + # sinks: [n_local_heads] -> broadcast to [B, H, S, 1] + sink_col = np.reshape(sinks, (1, n_local_heads, 1, 1)) + sink_col = np.broadcast_to( + sink_col, (batch_size, n_local_heads, seq_len, 1) + ).astype(scores.dtype) + scores = np.concatenate([scores, sink_col], axis=-1) + + attention_weights = softmax_kernel(scores) + # Drop the sink column before applying to values. 
+ attention_weights = attention_weights[:, :, :, :k_seq_len] + + # Apply attention to values: BHSS @ BHSD -> BHSD + output = attention_weights @ values + + # Transpose back: BHSD -> BSHD + output = output.transpose(0, 2, 1, 3) + output = output.reshape(batch_size, seq_len, -1) + + # Output projection (+ bias). Bias is replicated across ranks, so apply only + # on rank 0 to avoid double-counting after the all-reduce. + output_to_be_reduced = np.matmul(output, o_weight) + if dist.get_rank() == 0: + output_to_be_reduced = output_to_be_reduced + o_bias + + # All-reduce for tensor parallelism + output = cc.all_reduce( + output_to_be_reduced, + replica_groups=[list(range(dist.get_world_size()))], + reduce_op=np.add, + ) + + return output diff --git a/examples/models/gpt_oss/kernels/feedforward.py b/examples/models/gpt_oss/kernels/feedforward.py new file mode 100644 index 0000000..ce80dd5 --- /dev/null +++ b/examples/models/gpt_oss/kernels/feedforward.py @@ -0,0 +1,34 @@ +import numpy as np + + +def clamped_swiglu(gate, up, alpha, limit): + """gpt-oss clamped SwiGLU gating. + + Mirrors GptOssExperts._apply_gate: + gate = clamp(gate, max=limit) + up = clamp(up, -limit, limit) + glu = gate * sigmoid(alpha * gate) + out = (up + 1) * glu + """ + gate = np.minimum(gate, limit) + up = np.clip(up, -limit, limit) + glu = gate * (1.0 / (1.0 + np.exp(-alpha * gate))) + return (up + 1.0) * glu + + +def feedforward_kernel( + x, gate_up_weight, gate_up_bias, down_weight, down_bias, alpha, limit +): + """Single-expert feed-forward with clamped SwiGLU and biases. + + `gate_up_weight`/`gate_up_bias` are pre-arranged at weight-prep time so the + gate half comes first and the up half second (de-interleaved from the HF + layout), letting us split in half here. 
+ """ + mm_gup = np.matmul(x, gate_up_weight) + gate_up_bias + + xg, x_up = np.split(mm_gup, 2, axis=-1) + + gated = clamped_swiglu(xg, x_up, alpha, limit) + + return np.matmul(gated, down_weight) + down_bias diff --git a/examples/models/gpt_oss/kernels/rmsnorm.py b/examples/models/gpt_oss/kernels/rmsnorm.py new file mode 100644 index 0000000..60a2d07 --- /dev/null +++ b/examples/models/gpt_oss/kernels/rmsnorm.py @@ -0,0 +1,21 @@ +import numpy as np + + +def rmsnorm_kernel( + x, + weight, + eps: float, + compute_dtype=np.float32, # reduce numerical error +): + original_dtype = x.dtype + x = x.astype(compute_dtype) + weight = weight.astype(compute_dtype) + z = np.square(x) + z = np.mean(z, axis=-1, keepdims=True) + + z = (z + eps).astype(x.dtype) + z = x / np.sqrt(z) + + res = z * weight + res = res.astype(original_dtype) + return res diff --git a/examples/models/gpt_oss/kernels/rope.py b/examples/models/gpt_oss/kernels/rope.py new file mode 100644 index 0000000..78a69ea --- /dev/null +++ b/examples/models/gpt_oss/kernels/rope.py @@ -0,0 +1,55 @@ +import numpy as np + + +def compute_cos_sin_cache( + inv_freq, max_seq_len, attention_scaling=1.0, dtype=np.float32 +): + """Compute cosine/sine cache for RoPE from precomputed inverse frequencies. + + gpt-oss uses YaRN scaling, so `inv_freq` (length head_dim/2) and the + `attention_scaling` post-multiplier are computed once on the host from the HF + config and passed in here. + + Comptime: this runs on constant arguments only, so under Trainium compilation + the resulting arrays bake into the HLO graph as constants. On CPU it is plain + numpy. 
+ """ + inv_freq = np.asarray(inv_freq, dtype=np.float32) + t = np.arange(max_seq_len, dtype=np.float32) + freqs = np.outer(t, inv_freq) + + cos = (np.cos(freqs, dtype=np.float32) * attention_scaling).astype(dtype) + sin = (np.sin(freqs, dtype=np.float32) * attention_scaling).astype(dtype) + return cos, sin + + +def apply_rotary_emb_kernel(xq, xk, freqs_cos, freqs_sin): + """Apply rotary position embedding to query and key tensors. + + Matches gpt-oss `_apply_rotary_emb`: split each vector into first/second + halves and rotate, without interleaving. + """ + # Reshape `freqs_cos` and `freqs_sin` for broadcasting over (B, S, H, D/2). + freqs_cos = np.expand_dims(freqs_cos, axis=(0, 2)) + freqs_sin = np.expand_dims(freqs_sin, axis=(0, 2)) + + # Split the hidden states into two halves + half_h = xq.shape[-1] // 2 + xq0 = xq[:, :, :, :half_h] + xq1 = xq[:, :, :, half_h:] + + xk0 = xk[:, :, :, :half_h] + xk1 = xk[:, :, :, half_h:] + + # Apply rotary embedding between first and second halves + xq_out_0 = xq0 * freqs_cos - xq1 * freqs_sin + xq_out_1 = xq1 * freqs_cos + xq0 * freqs_sin + + xk_out_0 = xk0 * freqs_cos - xk1 * freqs_sin + xk_out_1 = xk1 * freqs_cos + xk0 * freqs_sin + + # Concatenate the results back together to form the final output + xq_out = np.concatenate([xq_out_0, xq_out_1], axis=-1) + xk_out = np.concatenate([xk_out_0, xk_out_1], axis=-1) + + return xq_out, xk_out diff --git a/examples/models/gpt_oss/kernels/sampling.py b/examples/models/gpt_oss/kernels/sampling.py new file mode 100644 index 0000000..851a1a9 --- /dev/null +++ b/examples/models/gpt_oss/kernels/sampling.py @@ -0,0 +1,175 @@ +import nki +import nki.isa as nisa +import nki.language as nl +import nkipy.distributed.collectives as cc +import numpy as np +import torch.distributed as dist + +# Import config from parent directory +from config import Config +from nkipy.core import ( + nki_op, # noqa: F401, make sure monkey patch is applied + tensor_apis, +) + +# Import kernels from the kernels 
directory +from .rmsnorm import rmsnorm_kernel + + +def stream_shuffle_broadcast(src, dst): + dst_npar = dst.shape[0] + free_dim = dst.shape[1] + shuffle_mask = [0] * 32 + + assert dst_npar % 32 == 0 + for i in range(dst_npar // 32): + nisa.nc_stream_shuffle( + src=src[0:1, :], + dst=dst[i * 32 : (i + 1) * 32, 0:free_dim], + shuffle_mask=shuffle_mask, + ) + + +@nki.jit +def nki_rmsnorm_kernel(input_tensor, weight, eps): + """ + RMSNorm NKI kernel - based on AWS official tutorial pattern. + Migrated to NKI Beta 2 API. + + Args: + input_tensor: Input tensor [batch*seq_len, hidden_size] + weight: RMSNorm weight parameter [hidden_size] + eps: Small epsilon for numerical stability + + Returns: + output: Normalized tensor with same shape as input + """ + MAX_P = 128 + + output = nl.ndarray( + input_tensor.shape, dtype=input_tensor.dtype, buffer=nl.shared_hbm + ) + assert input_tensor.shape[1] == weight.shape[0] + + num_rows = input_tensor.shape[0] + hidden_size = input_tensor.shape[1] + num_chunks = (num_rows + MAX_P - 1) // MAX_P + + # Load RMSNorm weight once into SBUF, reused by all rows + g_tile = nl.ndarray((1, hidden_size), dtype=weight.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_tile[0:1, 0:hidden_size], + src=weight.reshape((1, hidden_size))[0:1, 0:hidden_size], + ) + + for i in nl.affine_range(num_chunks): + p_start = i * MAX_P + valid_rows = min(MAX_P, num_rows - p_start) + + # Load valid rows from HBM (padded partitions are unused) + a = nl.ndarray((MAX_P, hidden_size), dtype=input_tensor.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=a[0:valid_rows, 0:hidden_size], + src=input_tensor[p_start : p_start + valid_rows, 0:hidden_size], + ) + + # a^2 -> t (reused below as normalized output) + t = nl.ndarray((MAX_P, hidden_size), dtype=input_tensor.dtype, buffer=nl.sbuf) + nisa.tensor_tensor(dst=t, data1=a, data2=a, op=nl.multiply) + + # sum(a^2) + sq_sum = nl.ndarray((MAX_P, 1), dtype=nl.float32, buffer=nl.psum) + nisa.tensor_reduce(dst=sq_sum, data=t, 
op=nl.add, axis=1) + + # rsqrt(mean(a^2) + eps), in-place + s = nl.ndarray((MAX_P, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=s, + data=sq_sum, + op0=nl.multiply, + operand0=1.0 / hidden_size, + op1=nl.add, + operand1=eps, + ) + nisa.activation(dst=s, data=s, op=nl.rsqrt) + + # a * rsqrt -> t + nisa.tensor_scalar(dst=t, data=a, operand0=s, op0=nl.multiply) + + # Broadcast weight and multiply + g_bcast = nl.ndarray((MAX_P, hidden_size), dtype=g_tile.dtype, buffer=nl.sbuf) + stream_shuffle_broadcast(g_tile, g_bcast) + nisa.tensor_tensor(dst=t, data1=t, data2=g_bcast, op=nl.multiply) + + # Store only valid rows back to HBM + nisa.dma_copy( + dst=output[p_start : p_start + valid_rows, 0:hidden_size], + src=t[0:valid_rows, 0:hidden_size], + ) + + return output + + +def greedy_sampling( + h, norm_weight, lm_head_weight, configs: Config, use_nki_rmsnorm=False +): + """Greedy sampling kernel for token generation.""" + + B, S, H = h.shape + # Note: this is just for showing how to use use a NKI kernel inside NKIPy + if use_nki_rmsnorm: + h = h.reshape(-1, H) # batch*seq_len, hidden_size + h = nki_rmsnorm_kernel(h, norm_weight, configs.norm_eps) + h = h.reshape(B, S, H) + else: + h = rmsnorm_kernel(h, norm_weight, configs.norm_eps) + + logits = h[:, -1, :] @ lm_head_weight + + logits, next_id = tensor_apis.topk(logits, k=1, axis=1) + logits_all = cc.all_gather( + logits, all_gather_dim=1, replica_groups=[list(range(dist.get_world_size()))] + ) + next_id_all = cc.all_gather( + next_id, all_gather_dim=1, replica_groups=[list(range(dist.get_world_size()))] + ) + + _, top_index = tensor_apis.topk(logits_all, k=1, axis=1) + final_next_id = np.empty_like(next_id) + + vocab_per_device = lm_head_weight.shape[1] + for b in range(configs.max_batch_size): + device_idx = top_index[b] + local_idx = next_id_all[b, device_idx] + global_idx = device_idx * vocab_per_device + local_idx + final_next_id[b] = global_idx + + return final_next_id + + +def 
def transformer_layer(
    x,
    start_pos,
    # attention weights
    qkv_weight,
    qkv_bias,
    o_weight,
    o_bias,
    sinks,
    input_weight,
    post_attention_weight,
    # moe weights
    router_weight,
    router_bias,
    gate_up_weight,
    gate_up_bias,
    down_weight,
    down_bias,
    # kv cache
    cache_k,
    cache_v,
    configs: Config,
    sliding_window=None,
):
    """Run one gpt-oss decoder layer: attention block, then mixture-of-experts block.

    The same function serves both phases; ``start_pos`` selects which one:

    * ``start_pos is None``  -> prefill (the full context is processed)
    * ``start_pos`` provided -> decode (a single token is processed)

    Args:
        x: Layer input; unpacked below as ``(B, L, D)``.
        start_pos: Decode position, or ``None`` for prefill.
        qkv_weight, qkv_bias, o_weight, o_bias, sinks: Attention parameters,
            forwarded unchanged to ``attention_kernel``.
        input_weight: RMSNorm gain applied before attention.
        post_attention_weight: RMSNorm gain applied before the MoE block.
        router_weight, router_bias: Router projection and its bias.
        gate_up_weight, gate_up_bias, down_weight, down_bias: Per-expert
            feed-forward parameters, indexed by expert id on the first axis.
        cache_k, cache_v: KV cache buffers, forwarded to ``attention_kernel``.
        configs: Model configuration (norm eps, RoPE tables, head counts,
            experts-per-token, SwiGLU alpha/limit).
        sliding_window: Forwarded to ``attention_kernel``; defaults to ``None``.

    Returns:
        The layer output: post-attention residual plus the all-reduced MoE output.
    """
    # ---- Attention block: pre-norm -> attention -> residual ----
    attn_out = attention_kernel(
        rmsnorm_kernel(x, input_weight, configs.norm_eps),
        qkv_weight,
        qkv_bias,
        sinks,
        configs.rope_inv_freq,
        configs.rope_attention_scaling,
        configs.num_heads,
        configs.head_dim,
        configs.num_kv_heads,
        cache_k,
        cache_v,
        start_pos=start_pos,
        o_weight=o_weight,
        o_bias=o_bias,
        sliding_window=sliding_window,
    )
    resid = x + attn_out

    # ---- MoE block: pre-norm -> route -> weighted expert sum ----
    batch, seq_len, hidden = resid.shape
    n_selected = configs.num_experts_per_tok

    moe_in = rmsnorm_kernel(resid, post_attention_weight, configs.norm_eps)

    # Biased router logits for every token: [B, L, n_experts].
    logits = np.matmul(moe_in, router_weight) + router_bias

    moe_out = np.empty_like(resid)

    # Tokens are routed one at a time, batch-major then sequence order.
    for bi in range(batch):
        for ti in range(seq_len):
            tok = moe_in[bi, ti, :]  # [D]

            # gpt-oss routing order: top-k over the RAW logits first, then a
            # softmax restricted to the chosen logits (not softmax-then-topk).
            picked_logits, picked_experts = tensor_apis.topk(
                logits[bi, ti], k=n_selected
            )
            mix = softmax_kernel(picked_logits)

            # Accumulate the routing-weighted outputs of the chosen experts.
            acc = tensor_apis.zeros(hidden, dtype=moe_out.dtype)
            for slot in range(n_selected):
                eid = picked_experts[slot]
                acc += mix[slot] * feedforward_kernel(
                    tok,
                    gate_up_weight[eid],
                    gate_up_bias[eid],
                    down_weight[eid],
                    down_bias[eid],
                    configs.swiglu_alpha,
                    configs.swiglu_limit,
                )

            moe_out[bi, ti] = acc

    # Tensor-parallel reduction. The expert matrices are split along the
    # intermediate dimension, so each rank holds a partial down-projection and
    # the partials add up to the full result. The down_proj bias is applied
    # inside feedforward_kernel; weight prep zeroes it on every rank except
    # rank 0 so the sum counts it once rather than world_size times.
    all_ranks = list(range(dist.get_world_size()))
    moe_out = cc.all_reduce(moe_out, replica_groups=[all_ranks], reduce_op=np.add)

    # Second residual connection.
    return resid + moe_out
class LazyWeights:
    """Read tensors by name from a local safetensors checkpoint on demand.

    Two on-disk layouts are supported inside ``model_dir``:

    * sharded: ``model.safetensors.index.json`` whose ``weight_map`` maps each
      tensor name to the shard file that stores it;
    * single-file: one ``model.safetensors`` holding every tensor.

    Shard files are opened lazily on first access and the handles are cached
    for the lifetime of the object.

    Raises:
        FileNotFoundError: if ``model_dir`` contains neither layout. This makes
            a wrong path (e.g. a hub repo id that was never downloaded) fail
            immediately instead of on the first tensor read.
    """

    def __init__(self, model_dir):
        self.model_dir = model_dir
        self._handles = {}
        self._single = os.path.join(model_dir, "model.safetensors")

        index_path = os.path.join(model_dir, "model.safetensors.index.json")
        if os.path.exists(index_path):
            with open(index_path, encoding="utf-8") as f:
                self.weight_map = json.load(f)["weight_map"]
        elif os.path.exists(self._single):
            # Single-file checkpoint: no name -> file mapping is needed.
            self.weight_map = None
        else:
            raise FileNotFoundError(
                f"No safetensors checkpoint found in {model_dir!r}: expected "
                "'model.safetensors.index.json' or 'model.safetensors'. "
                "Pass a local directory containing the downloaded weights."
            )

    def _handle(self, key):
        """Return the (cached) open safetensors handle of the file storing ``key``."""
        if self.weight_map is not None:
            path = os.path.join(self.model_dir, self.weight_map[key])
        else:
            path = self._single
        if path not in self._handles:
            self._handles[path] = safe_open(path, framework="pt", device="cpu")
        return self._handles[path]

    def get(self, key):
        """Load and return the tensor named ``key``."""
        return self._handle(key).get_tensor(key)

    def has(self, key):
        """Return True if the checkpoint contains a tensor named ``key``."""
        if self.weight_map is not None:
            return key in self.weight_map
        return key in self._handle(key).keys()
def build_shard(weights, config, rank, world_size, dtype):
    """Build the name -> tensor dict that makes up one rank's shard.

    Every tensor read from ``weights`` is cast to ``dtype``. Tensors are then
    either stored as-is (norm gains, token embedding, router), split with
    ``chunk_rank`` / ``shard_kv`` for this ``rank``, or -- for the output and
    down-projection biases -- kept on rank 0 and replaced by zeros elsewhere.

    Args:
        weights: Source of checkpoint tensors; only ``weights.get(name)`` is used.
        config: Model config; reads ``num_hidden_layers``,
            ``num_key_value_heads`` and ``head_dim``.
        rank: Index of the rank this shard is built for.
        world_size: Total number of tensor-parallel ranks.
        dtype: torch dtype of the stored tensors.

    Returns:
        dict mapping ``"layers.{i}.<name>"`` (plus the shared ``norm_weight``,
        ``tok_embedding`` and ``lm_head_weight``) to tensors.
    """
    n_layers = config.num_hidden_layers
    n_kv_heads = config.num_key_value_heads
    head_dim = config.head_dim

    out = {}

    # Shared (non-sharded) tensors: identical on every rank.
    out["norm_weight"] = weights.get("model.norm.weight").to(dtype)
    out["tok_embedding"] = weights.get("model.embed_tokens.weight").to(dtype)
    # lm_head: colwise shard along vocab, store transposed (hidden, vocab_local).
    lm_head = weights.get("lm_head.weight").to(dtype)
    out["lm_head_weight"] = chunk_rank(lm_head, 0, rank, world_size).T.contiguous()

    for layer_id in range(n_layers):
        # Checkpoint-side name prefix for this layer.
        p = f"model.layers.{layer_id}"

        # ---- Attention projections (nn.Linear: y = x @ W.T) ----
        q_w = weights.get(f"{p}.self_attn.q_proj.weight").to(dtype)
        k_w = weights.get(f"{p}.self_attn.k_proj.weight").to(dtype)
        v_w = weights.get(f"{p}.self_attn.v_proj.weight").to(dtype)
        o_w = weights.get(f"{p}.self_attn.o_proj.weight").to(dtype)
        q_b = weights.get(f"{p}.self_attn.q_proj.bias").to(dtype)
        k_b = weights.get(f"{p}.self_attn.k_proj.bias").to(dtype)
        v_b = weights.get(f"{p}.self_attn.v_proj.bias").to(dtype)
        o_b = weights.get(f"{p}.self_attn.o_proj.bias").to(dtype)
        sinks = weights.get(f"{p}.self_attn.sinks").to(dtype)

        # Q: colwise (shard heads). K/V: shard heads w/ replication fallback.
        q_w_s = chunk_rank(q_w, 0, rank, world_size)
        q_b_s = chunk_rank(q_b, 0, rank, world_size)
        k_w_s = shard_kv(k_w, head_dim, n_kv_heads, rank, world_size)
        v_w_s = shard_kv(v_w, head_dim, n_kv_heads, rank, world_size)
        # shard_kv works on 2-D (out, in) weights, so the 1-D biases are viewed
        # as a single-column matrix for sharding and flattened back afterwards.
        k_b_s = shard_kv(
            k_b.reshape(-1, 1), head_dim, n_kv_heads, rank, world_size
        ).reshape(-1)
        v_b_s = shard_kv(
            v_b.reshape(-1, 1), head_dim, n_kv_heads, rank, world_size
        ).reshape(-1)

        # Fuse into x @ W form: W stored as (hidden, out_local), columns laid
        # out as [Q | K | V]; the bias uses the same [Q | K | V] ordering.
        qkv_weight = torch.cat([q_w_s.T, k_w_s.T, v_w_s.T], dim=-1).contiguous()
        qkv_bias = torch.cat([q_b_s, k_b_s, v_b_s], dim=-1).contiguous()
        out[f"layers.{layer_id}.qkv_weight"] = qkv_weight
        out[f"layers.{layer_id}.qkv_bias"] = qkv_bias

        # o_proj: rowwise (shard along input = heads*head_dim).
        out[f"layers.{layer_id}.o_weight"] = chunk_rank(
            o_w, 1, rank, world_size
        ).T.contiguous()
        # o_bias is not sharded: rank 0 keeps the real values and every other
        # rank stores zeros, so it contributes exactly once when the per-rank
        # partial outputs are summed (presumably by an all-reduce in the
        # attention path -- confirm against the attention kernel).
        out[f"layers.{layer_id}.o_bias"] = (
            o_b if rank == 0 else torch.zeros_like(o_b)
        ).contiguous()

        # sinks: per-head, shard along heads to match Q.
        out[f"layers.{layer_id}.sinks"] = chunk_rank(sinks, 0, rank, world_size)

        # ---- RMSNorm gains (replicated) ----
        out[f"layers.{layer_id}.input_weight"] = weights.get(
            f"{p}.input_layernorm.weight"
        ).to(dtype)
        out[f"layers.{layer_id}.post_attention_weight"] = weights.get(
            f"{p}.post_attention_layernorm.weight"
        ).to(dtype)

        # ---- Router (replicated) ----
        router_w = weights.get(f"{p}.mlp.router.weight").to(
            dtype
        )  # (n_experts, hidden)
        router_b = weights.get(f"{p}.mlp.router.bias").to(dtype)
        out[f"layers.{layer_id}.router_weight"] = (
            router_w.T.contiguous()
        )  # (hidden, n_experts)
        out[f"layers.{layer_id}.router_bias"] = router_b.contiguous()

        # ---- Experts (MXFP4 -> bf16, de-interleave gate/up, shard along inter) ----
        # The *_blocks / *_scales tensors are passed to the dequantizer as read
        # (no dtype cast); only the biases are cast here.
        gu_blocks = weights.get(f"{p}.mlp.experts.gate_up_proj_blocks")
        gu_scales = weights.get(f"{p}.mlp.experts.gate_up_proj_scales")
        gu_bias = weights.get(f"{p}.mlp.experts.gate_up_proj_bias").to(dtype)
        dn_blocks = weights.get(f"{p}.mlp.experts.down_proj_blocks")
        dn_scales = weights.get(f"{p}.mlp.experts.down_proj_scales")
        dn_bias = weights.get(f"{p}.mlp.experts.down_proj_bias").to(dtype)

        # gate_up: (E, hidden, 2*inter) with gate/up INTERLEAVED on last dim.
        gate_up = convert_moe_packed_tensors(gu_blocks, gu_scales).to(dtype)
        E, hidden, two_inter = gate_up.shape
        # Pair up adjacent columns: index 0 of each pair is gate, index 1 is up.
        gate_up = gate_up.reshape(E, hidden, two_inter // 2, 2)
        gate = gate_up[..., 0]  # (E, hidden, inter)
        up = gate_up[..., 1]
        # Shard gate and up separately so each rank gets matching columns.
        gate = chunk_rank(gate, 2, rank, world_size)
        up = chunk_rank(up, 2, rank, world_size)
        # Store [gate | up] so the kernel can split in half.
        out[f"layers.{layer_id}.gate_up_weight"] = torch.cat(
            [gate, up], dim=-1
        ).contiguous()

        # Same de-interleave + shard + [gate | up] layout for the bias.
        gu_bias = gu_bias.reshape(E, two_inter // 2, 2)
        gate_b = chunk_rank(gu_bias[..., 0], 1, rank, world_size)
        up_b = chunk_rank(gu_bias[..., 1], 1, rank, world_size)
        out[f"layers.{layer_id}.gate_up_bias"] = torch.cat(
            [gate_b, up_b], dim=-1
        ).contiguous()

        # down: (E, inter, hidden), shard along inter (dim 1).
        down = convert_moe_packed_tensors(dn_blocks, dn_scales).to(dtype)
        out[f"layers.{layer_id}.down_weight"] = chunk_rank(down, 1, rank, world_size)
        # down bias is not sharded: real values on rank 0, zeros elsewhere, so
        # the cross-rank sum of expert outputs counts it exactly once.
        out[f"layers.{layer_id}.down_bias"] = (
            dn_bias if rank == 0 else torch.zeros_like(dn_bias)
        ).contiguous()

    return out
+ ) + parser.add_argument( + "--model-name", required=True, help="HF repo or local path to gpt-oss" + ) + parser.add_argument("--output-dir", default="sharded_gpt_oss") + parser.add_argument( + "--world-size", type=int, required=True, help="Number of tensor-parallel ranks" + ) + parser.add_argument( + "--dtype", choices=["f32", "bf16"], default="bf16", help="Output dtype" + ) + args = parser.parse_args() + dtype = {"f32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + + preshard_model(args.model_name, args.output_dir, args.world_size, dtype=dtype) diff --git a/examples/models/gpt_oss/test.sh b/examples/models/gpt_oss/test.sh new file mode 100644 index 0000000..d0f2759 --- /dev/null +++ b/examples/models/gpt_oss/test.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Test script for gpt-oss-20b on Trainium +# Usage: bash test.sh + +set -e + +echo "==========================================" +echo "gpt-oss-20b Test Script" +echo "==========================================" + +# Step 1: Clean compilation cache +echo "" +echo "[1/3] Cleaning compilation cache..." +rm -rf build/ 2>/dev/null || true +echo "✓ Cache cleaned" + +# Step 2: Check and prepare weights +echo "" +echo "[2/3] Checking weights..." + +MODEL="${MODEL:-openai/gpt-oss-20b}" +WEIGHTS_PATH="${WEIGHTS_PATH:-./tmp_gpt-oss-20b}" +TP_DEGREE="${TP_DEGREE:-4}" # Tensor parallelism + +if [ ! -d "$WEIGHTS_PATH" ]; then + echo "Weights not found. Converting (dequantizing MXFP4 to bf16)..." + python tensor_preparation.py --model-name "$MODEL" --world-size "$TP_DEGREE" --output-dir="$WEIGHTS_PATH" + echo "✓ Weights prepared" +else + echo "✓ Weights found at $WEIGHTS_PATH" +fi + +# Step 3: Run example +echo "" +echo "[3/3] Running gpt-oss inference..." 
def print_log(msg, rank_list=(0,), verbose=0):
    """Print ``msg`` once per selected rank and flush stdout.

    When ``torch.distributed`` has not been initialized the message is printed
    as-is. Otherwise it is printed, prefixed with ``[RANK n]``, only on ranks
    listed in ``rank_list``.

    Args:
        msg: Message to print.
        rank_list: Ranks allowed to print in a distributed run. The default is
            an immutable tuple so it cannot be shared and mutated across calls.
        verbose: Currently unused; kept so existing call sites stay valid.
    """
    if not dist.is_initialized():
        print(msg, flush=True)
        return

    # Query the rank once and reuse it for both the filter and the prefix.
    rank = dist.get_rank()
    if rank in rank_list:
        print(f"[RANK {rank}] {msg}", flush=True)