76 changes: 76 additions & 0 deletions examples/models/gpt_oss/README.md
@@ -0,0 +1,76 @@
# gpt-oss on Trainium

A clean implementation of OpenAI's gpt-oss MoE models (e.g., `gpt-oss-20b`) for
AWS Trainium, built on NKIPy.

## Setup

``` sh
cd nkipy
uv sync --all-groups
source .venv/bin/activate
cd examples/models/gpt_oss
```

## Quickstart

`test.sh` handles weight preparation and runs a generation end-to-end:

``` sh
./test.sh
```

Or run generation directly (assumes weights are already prepared):

``` sh
WEIGHTS=./tmp_gpt-oss-20b
TP=4

torchrun --nproc-per-node $TP gpt_oss.py \
    -n 500 --checkpoint $WEIGHTS --model openai/gpt-oss-20b \
    "The capital of France is"
```

You can point `--model` at a local checkpoint directory too.

## Weight preparation

gpt-oss ships its experts **MXFP4-quantized** (`*_blocks` / `*_scales`). The prep
step dequantizes them to bf16 so the NKI kernels run purely in bf16, and it
shards every tensor for tensor parallelism:

``` sh
python tensor_preparation.py \
    --model-name openai/gpt-oss-20b \
    --world-size 4 \
    --output-dir ./tmp_gpt-oss-20b
```

This writes `shard_{rank}.safetensors` files. Dequantized bf16 weights are
larger than the packed checkpoint (~40 GB total), so make sure you have disk
headroom.
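
For intuition, the dequantization step can be sketched in plain Python. This is an illustration of the MXFP4 layout (blocks of 32 values, two 4-bit E2M1 codes per byte, one shared power-of-two E8M0 scale per block), not the code in `tensor_preparation.py`. The helper name `dequantize_mxfp4_block` is hypothetical, and the low-nibble-first ordering is an assumption about the packing.

```python
# FP4 (E2M1) code -> value; codes 8..15 are the negated codes 0..7.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequantize_mxfp4_block(packed: bytes, scale_byte: int) -> list:
    """Hypothetical helper: expand one packed block to Python floats.

    packed     -- 16 bytes holding 32 four-bit codes (assumed low nibble first)
    scale_byte -- shared E8M0 exponent, stored with a bias of 127
    """
    scale = 2.0 ** (scale_byte - 127)
    out = []
    for byte in packed:
        out.append(FP4_VALUES[byte & 0x0F] * scale)  # low nibble
        out.append(FP4_VALUES[byte >> 4] * scale)    # high nibble
    return out


# One block: two non-zero bytes followed by zeros, scale = 2**(128 - 127) = 2.
block = bytes([0x21, 0xF7] + [0x00] * 14)
values = dequantize_mxfp4_block(block, scale_byte=128)
print(values[:4])
```

Each packed byte expands to two bf16 values (4 bytes), which is why the prepared shards take noticeably more disk than the original checkpoint.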

## Architecture notes

gpt-oss differs from the Qwen3 MoE example in several ways, all handled here:

| Feature | Handling |
|---|---|
| MXFP4 experts | Dequantized to bf16 at prep time (`tensor_preparation.py`) |
| Interleaved gate/up | De-interleaved to `[gate \| up]` at prep time |
| Clamped SwiGLU | `(up+1) * gate*sigmoid(alpha*gate)` with `clamp(limit=7)` (`kernels/feedforward.py`) |
| Attention sinks | Per-head sink logit concatenated into softmax, then dropped (`kernels/attention.py`) |
| QKV / output bias | Carried through prep and added in the attention kernel |
| Sliding-window attention | Alternating sliding (window=128) / full causal layers; one kernel compiled per attention type |
| YaRN RoPE | `inv_freq` + attention-scaling precomputed from HF config (`config.py`) and baked into the cos/sin cache |
| Router | top-k on raw logits (+bias), then softmax over the selected logits |
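
Three of the rows above are compact enough to sketch in plain Python. This is a scalar reference for intuition only, not the NKI kernels. The function names are hypothetical, and the exact clamp ranges (gate clamped from above only, up clamped to `[-limit, limit]`) are an assumption; the table only states `clamp(limit=7)`.

```python
import math


def clamped_swiglu(gate, up, alpha=1.702, limit=7.0):
    """Element-wise clamped SwiGLU over two equal-length lists of floats."""
    out = []
    for g, u in zip(gate, up):
        g = min(g, limit)                # assumption: gate clamped from above only
        u = max(-limit, min(u, limit))   # assumption: up clamped to [-limit, limit]
        sigmoid = 1.0 / (1.0 + math.exp(-alpha * g))
        out.append((u + 1.0) * g * sigmoid)
    return out


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_topk(logits, k):
    """Pick the k largest router logits (bias already added), then
    softmax over only the selected logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return top, softmax([logits[i] for i in top])


def softmax_with_sink(scores, sink):
    """Softmax over the scores plus one sink logit; the sink's share is dropped."""
    return softmax(list(scores) + [sink])[:-1]


activations = clamped_swiglu([0.0, 10.0], [1.0, 10.0])
experts, weights = route_topk([0.1, 2.0, -1.0, 2.0], k=2)
probs = softmax_with_sink([1.0, 1.0], sink=1.0)
```

Note that the attention probabilities returned by `softmax_with_sink` sum to less than 1: the sink absorbs part of the probability mass without contributing a value vector.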

## Files

| File | Purpose |
|---|---|
| `gpt_oss.py` | Model definition (`GptOssModel`) and text generation |
| `config.py` | Model configuration (incl. YaRN RoPE precompute) |
| `tensor_preparation.py` | Dequantize, reshape, and shard HF weights for TP |
| `test.sh` | Smoke test: prepares weights and runs generation |
| `kernels/` | Attention, feed-forward, RoPE, RMSNorm, softmax, sampling kernels |
85 changes: 85 additions & 0 deletions examples/models/gpt_oss/config.py
@@ -0,0 +1,85 @@
from dataclasses import dataclass

import numpy as np
import torch.distributed as dist
from neuronxcc.nki.language import bfloat16
from transformers import AutoConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# to control compiler_args
DTYPE = bfloat16


@dataclass
class Config:
    hidden_size: int
    num_heads: int
    head_dim: int
    num_kv_heads: int
    num_layers: int
    num_experts_per_tok: int
    num_experts: int
    # RoPE (YaRN) inverse frequencies and post-scaling, precomputed from HF.
    rope_inv_freq: np.ndarray
    rope_attention_scaling: float
    # Per-layer attention type: "sliding_attention" or "full_attention".
    layer_types: list
    sliding_window: int
    # Clamped-SwiGLU parameters (gpt-oss specific).
    swiglu_alpha: float = 1.702
    swiglu_limit: float = 7.0
    context_len: int = None
    max_new_tokens: int = None
    max_batch_size: int = 1
    norm_eps: float = 1e-5
    intermediate_size: int = 2880
    max_seq_len: int = 4096
    dtype: np.dtype = DTYPE
    additional_compiler_args_nkipy: str = "--lnc 1"
    # Decoder-layer indices whose outputs are tapped for EAGLE-3 speculative
    # decoding. None disables capture (the default, non-speculative path).
    aux_layers: tuple = None

    def is_sliding(self, layer_id: int) -> bool:
        return self.layer_types[layer_id] == "sliding_attention"

    @staticmethod
    def default_aux_layers(num_layers: int) -> tuple:
        """EAGLE-3's standard low/mid/high decoder-layer taps (vLLM convention)."""
        return (2, num_layers // 2, num_layers - 3)

    @staticmethod
    def peagle_aux_layers(num_layers: int) -> tuple:
        """P-EAGLE's tap layers (0-indexed, captures before each layer)."""
        return (0, num_layers // 2, num_layers - 1)


def get_config(model_name, context_len, max_new_tokens):
    hf_config = AutoConfig.from_pretrained(model_name)

    # YaRN RoPE: precompute inverse frequencies + attention scaling factor once.
    # These are constants (independent of runtime tensors), so we bake them into
    # the kernel's cos/sin cache at compile time.
    rope_init_fn = ROPE_INIT_FUNCTIONS[hf_config.rope_parameters["rope_type"]]
    inv_freq, attention_scaling = rope_init_fn(hf_config, device=None)

    config = Config(
        hidden_size=hf_config.hidden_size,
        intermediate_size=hf_config.intermediate_size // dist.get_world_size(),
        num_heads=hf_config.num_attention_heads,
        head_dim=hf_config.head_dim,
        num_kv_heads=hf_config.num_key_value_heads,
        norm_eps=hf_config.rms_norm_eps,
        num_layers=hf_config.num_hidden_layers,
        num_experts_per_tok=hf_config.num_experts_per_tok,
        num_experts=hf_config.num_local_experts,
        rope_inv_freq=np.asarray(inv_freq, dtype=np.float32),
        rope_attention_scaling=float(attention_scaling),
        layer_types=list(hf_config.layer_types),
        sliding_window=hf_config.sliding_window,
        swiglu_alpha=getattr(hf_config, "swiglu_alpha", 1.702),
        swiglu_limit=hf_config.swiglu_limit,
        context_len=context_len,
        max_new_tokens=max_new_tokens,
    )
    return config
190 changes: 190 additions & 0 deletions examples/models/gpt_oss/eagle/README.md
@@ -0,0 +1,190 @@
# P-EAGLE Speculative Decoding for gpt-oss on Trainium

Parallel-drafting speculative decoding using [P-EAGLE](https://arxiv.org/abs/2602.01469)
for the gpt-oss model family on AWS Trainium. Generates K draft tokens in a
**single forward pass** (not K sequential passes), then verifies them against the
target in one multi-token target forward.

## Setup

``` sh
cd nkipy
uv sync --all-groups
source .venv/bin/activate
cd examples/models/gpt_oss
```

## Quickstart

### 1. Prepare weights

The target model (gpt-oss-20b) must already be prepared (see `../README.md`):

``` sh
# Target (if not already done)
python tensor_preparation.py \
    --model-name /path/to/gpt-oss-20b \
    --world-size 4 --output-dir ./tmp_gpt-oss-20b

# Drafter (P-EAGLE, replicated on every rank — small, ~3.6 GB)
python eagle/tensor_preparation.py \
    --model-name /path/to/GPT-OSS-20B-P-EAGLE \
    --output-dir ./eagle/tmp_p-eagle
```

### 2. Run speculative decoding

``` sh
TP=4
torchrun --nproc-per-node $TP eagle/speculate.py \
    --target-checkpoint ./tmp_gpt-oss-20b \
    --draft-checkpoint ./eagle/tmp_p-eagle \
    --model /path/to/gpt-oss-20b \
    --draft-model /path/to/GPT-OSS-20B-P-EAGLE \
    -n 256 -k 7 \
    "Write a Python function that implements binary search."
```

Output includes acceptance metrics:

```
Time to first token: 0.6s
Generated 256 tokens in N verify steps
Mean acceptance length: X.XX (K=7)
Decode tokens/sec: XX.XX
```

## How it works

### Speculation loop

```
1. Target prefill on prompt → first token + 3 tapped hidden states
2. Seed: run first token through target decode → hidden states for drafter
3. Loop:
   a. Drafter: K tokens in ONE parallel forward pass
   b. Target verify: run [last_accepted, draft_0, ..., draft_{K-1}] through
      target layers (seq_len = K+1) with block-causal mask
   c. Accept: longest prefix where draft[i] == target_argmax[i]
   d. Emit accepted tokens + bonus correction token
   e. Advance KV cache position by (accepted + 1)
```
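
Steps 3c and 3d can be sketched as a small host-side function. This is a minimal illustration under the indexing stated in the loop (`target_argmax[i]` is the target's greedy token for the slot that `draft[i]` fills), not the code in `speculate.py`; the name `accept_greedy` is hypothetical.

```python
def accept_greedy(draft, target_argmax):
    """Return (accepted_count, emitted_tokens) for one verify step.

    draft          -- the K draft tokens
    target_argmax  -- K+1 greedy target tokens, one per verified position
    """
    accepted = 0
    for d, t in zip(draft, target_argmax):
        if d != t:
            break
        accepted += 1
    # The target's own token at the first mismatch (or after the last draft
    # when everything matched) is always emitted, so each step yields
    # accepted + 1 tokens.
    emitted = list(draft[:accepted]) + [target_argmax[accepted]]
    return accepted, emitted


partial = accept_greedy([5, 9, 3], [5, 9, 7, 1])   # third draft token rejected
full = accept_greedy([5, 9, 3], [5, 9, 3, 1])      # all drafts accepted
none = accept_greedy([5, 9, 3], [8, 9, 3, 1])      # first draft token rejected
```

Because the bonus token always comes from the target, every step makes progress by at least one token even when no draft token is accepted.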

### P-EAGLE parallel drafting (K tokens in one pass)

Unlike autoregressive EAGLE which runs K sequential drafter passes, P-EAGLE
generates all K draft tokens simultaneously:

| Position | Embedding input | Hidden state input |
|----------|----------------|-------------------|
| 0 (NTP) | `embed(last_accepted_token)` | `fc(concat(aux_layer_2, aux_layer_12, aux_layer_21))` — real target hidden |
| 1..K-1 (MTP) | `embed(ptd_token_id)` — placeholder | `fc(mask_hidden)` — learnable shared hidden |

All K positions attend under a **cross-depth causal mask** (depth d sees depths
≤ d) through the EAGLE-3 fusion midlayer + 3 plain Llama decoder layers. Each
position's `lm_head` logit gives one draft token.
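
The input layout and mask described above can be sketched as follows. This is an illustration only: it ignores the drafter's KV cache (cached context positions would be visible to every depth), the function names are hypothetical, and the default `ptd_token_id` is the value listed for this checkpoint in the architecture table.

```python
def build_draft_inputs(last_accepted_token, k, ptd_token_id=201020):
    """Token ids for the K drafter positions: one real token (NTP),
    then K-1 placeholder tokens (MTP)."""
    return [last_accepted_token] + [ptd_token_id] * (k - 1)


def cross_depth_mask(k):
    """mask[d][j] is True when depth d may attend to depth j, i.e. j <= d."""
    return [[j <= d for j in range(k)] for d in range(k)]


token_ids = build_draft_inputs(last_accepted_token=42, k=4)
mask = cross_depth_mask(3)
```

Only position 0 carries real target information; the remaining positions differ from one another only through their position and what the mask lets them see.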

### Architecture details

The P-EAGLE drafter (`GPT-OSS-20B-P-EAGLE`, ~3.6 GB bf16):

| Component | Description |
|-----------|-------------|
| `fc` (8640→2880) | Fuses 3 target hidden states (layers 2, 12, 21 of 24-layer target) |
| `midlayer` | EAGLE-3 fusion decoder layer: attention takes 2×hidden (embed⊕hidden), has `hidden_norm` |
| `layers.1/2/3` | Plain Llama decoder layers (SiLU MLP, llama3 RoPE) |
| `mask_hidden` (1,1,8640) | Learnable shared hidden state for MTP positions |
| `ptd_token_id` = 201020 | Placeholder token whose embedding fills MTP positions |
| `d2t` / `t2d` | Draft↔target vocab mapping (identity for this checkpoint) |
| `lm_head` (2880→201088) | Full target vocab, replicated on every rank |

### Verification

The target verifies K+1 candidate tokens in a single multi-token forward pass:
- Runs the full gpt-oss decoder stack with `seq_len = K+1` at a runtime offset
- Uses absolute-position RoPE and a block-causal attention mask
- Writes K+1 new KV cache entries contiguously
- Produces per-position greedy argmax via cross-rank reduction

**Greedy acceptance makes KV rollback implicit**: rejected speculative entries are
overwritten by the next verify pass, and the causal mask prevents any query from
attending past its own position.
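
The implicit rollback can be illustrated with a toy cache of token labels. This is a sketch under the assumptions stated above (contiguous writes at a runtime offset, position advanced by `accepted + 1`), not the verify kernel; `block_causal_mask` and `verify_step` are hypothetical names.

```python
def block_causal_mask(offset, num_new, cache_size):
    """mask[i][p] is True when new query i (absolute position offset + i)
    may attend to cache slot p."""
    return [[p <= offset + i for p in range(cache_size)] for i in range(num_new)]


def verify_step(cache, offset, tokens, accepted):
    """Write K+1 entries starting at `offset`, then advance by accepted + 1."""
    for i, tok in enumerate(tokens):
        cache[offset + i] = tok  # stale entries are simply overwritten
    return offset + accepted + 1


cache = [None] * 12
# Step 1: only the first of three drafts is accepted, so slots 6 and 7 go stale.
offset = verify_step(cache, 4, ["last", "d0", "d1", "d2"], accepted=1)
stale = cache[6:8]
# Step 2: the next verify pass starts at slot 6 and overwrites the stale entries.
mask = block_causal_mask(offset, 4, len(cache))
next_offset = verify_step(cache, offset, ["bonus", "e0", "e1", "e2"], accepted=3)
```

In step 2 the first query sits at absolute position 6, so its mask row ends at slot 6, which by then holds the freshly written entry rather than the rejected draft.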

## Files

| File | Purpose |
|------|---------|
| `speculate.py` | Main entry: speculation loop orchestrating target + drafter |
| `config.py` | `EagleConfig` for the P-EAGLE drafter (llama3 RoPE, fc, mask_hidden, K) |
| `tensor_preparation.py` | Convert P-EAGLE checkpoint to x@W form (replicated, no TP) |
| `drafter_model.py` | Device-side drafter: loads weights, compiles kernel, runs draft |
| `kernels/drafter.py` | Parallel-drafting forward kernel (K tokens in one pass) |
| `kernels/drafter_layer.py` | EAGLE-3 fusion midlayer + plain Llama layers |
| `kernels/verify.py` | Multi-position greedy argmax for verification |
| `kernels/rope.py` | llama3 RoPE (different from target's YaRN RoPE) |
| `kernels/rmsnorm.py`, `softmax.py` | Leaf kernels (copied from base) |

## Validation

| What | Result |
|------|--------|
| Drafter kernel math | ✅ All 7 draft tokens match independent PyTorch reference |
| Speculation output | ✅ Lossless — output matches HF greedy baseline exactly |
| Drafter with HF hidden states | ✅ Draft[0] matches target greedy perfectly |
| Multi-token verify | ✅ Block-causal mask + KV scatter correct |

## Known limitation: acceptance length

The current CPU-side acceptance length is ~1.4 tokens/step, versus the paper's
~3.7 for GPT-OSS 20B at K=5. The NTP (depth 0) position works correctly: draft[0]
frequently matches the target's greedy token. The MTP (depth 1+) positions
underperform, producing generic tokens instead of context-specific ones.

### What's been verified

- Drafter NTP produces the correct next token when given HF hidden states ✅
- Drafter KV cache is necessary and improves acceptance from 1.0 to 1.4 ✅
- The EAGLE shifted-token convention (input_ids shifted +1 vs hidden states)
matches vLLM's implementation ✅
- Hidden-state capture point (output of tap layers) matches what vLLM uses ✅
- The midlayer concat `[norm(embed), norm(hidden)]` → 2H → attention → H output
with H-wide residual matches vLLM's `Eagle3DecoderLayer` ✅

### vLLM reference (parallel_drafting)

These notes come from reading the vLLM installation at
`private-vllm-neuron/.venv`. Key findings:

1. vLLM's parallel drafting produces ALL K tokens in **one forward pass**: the
expanded input contains [shifted context tokens | bonus (next_token) | K-1
ptd_token positions]. All go through the model together with PagedAttention.

2. The `parallel_drafting_hidden_state_tensor` = `fc(mask_hidden)` (the fc-fused
mask_hidden at 2880 dim), placed at the MTP positions in the hidden_states
input to the model.

3. The Triton kernel `copy_and_expand_eagle_inputs_kernel` handles the layout:
positions are sequential (start_pos + j), parallel-draft slots get
`ptd_token_id` for input_ids and `parallel_drafting_hidden_state_tensor` for
hidden_states.

4. The model's `forward(input_ids, positions, hidden_states)` takes all three
as separate tensors of the same length. Only the midlayer (layer 0)
concatenates embeds with hidden to produce 2H; subsequent layers are standard.

### Remaining gap to investigate

The MTP positions are architecturally correct but still produce generic
predictions. The most likely remaining issue is an off-by-one in how the
drafter's RoPE positions map to the target's absolute positions during the
KV-cached speculation loop. vLLM assigns positions sequentially from the start of
the context, and the parallel-draft slots get the positions immediately following
the last valid token. Our `drafter_cpu.py` uses the same formula,
`torch.arange(cache_len, cache_len + K)`, so any offset error would have to come
from the value of `cache_len` at each step rather than from the formula itself.

### Path forward

1. Run the drafter via vLLM on GPU with this exact checkpoint and capture the
actual acceptance length (confirms the checkpoint quality ceiling)
2. If vLLM achieves ~3.7, the issue is in our inference loop (position/hidden
alignment during rollback+extend)
3. If vLLM also gets ~1.4, the checkpoint may be under-trained for this prompt
distribution (the paper evaluates on HumanEval/MT-Bench/GSM-8K specifically)