Skip to content

[WIP] Feat/gpt oss example#63

Draft
ymwangg wants to merge 6 commits into
mainfrom
feat/gpt-oss-example
Draft

[WIP] Feat/gpt oss example#63
ymwangg wants to merge 6 commits into
mainfrom
feat/gpt-oss-example

Conversation

@ymwangg

@ymwangg ymwangg commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Add p-eagle gpt-oss-20b example.

ymwangg added 6 commits June 25, 2026 13:38
Add a NKIPy example for OpenAI's gpt-oss MoE models (gpt-oss-20b / 120b),
mirroring the qwen3 example structure. The implementation is fully
config-driven, so both sizes share one codebase.

gpt-oss-specific handling:
- MXFP4 experts dequantized to bf16 at prep time
- interleaved gate/up de-interleaved at prep time
- clamped SwiGLU with gate_up/down biases
- per-head attention sinks + QKV/O biases (no QK-norm)
- alternating sliding-window / full attention (one kernel per type)
- YaRN RoPE (inv_freq precomputed from HF config)
- router with top-k-then-softmax and router bias

Validated against HF on trn2 (TP=4): every generated token matches HF's
argmax or a bf16-resolution tie.
Implements parallel-drafting P-EAGLE (arXiv 2602.01469) on top of the
gpt-oss base model for speculative decoding on Trainium.

Components added (examples/models/gpt_oss/eagle/):
- config.py: EagleConfig for the 4-layer P-EAGLE drafter (llama3 RoPE,
  fc fusion, mask_hidden/ptd_token_id, d2t vocab map)
- tensor_preparation.py: convert P-EAGLE checkpoint to x@W form (replicated)
- kernels/drafter.py: parallel-drafting forward - K tokens in one pass via
  NTP (real hidden) + MTP (mask_hidden) positions with cross-depth mask
- kernels/drafter_layer.py: EAGLE-3 fusion midlayer + plain Llama layers
- kernels/verify.py: multi-position greedy argmax for verification
- drafter_model.py: device-side drafter model + compile
- speculate.py: full speculation loop (prefill → draft → verify → accept)

Base model changes:
- config.py: added aux_layers config + default_aux_layers() for EAGLE-3 taps
- gpt_oss.py: run_prefill() now optionally captures pre-layer hidden states
  at the 3 EAGLE-3 tap layers (2, L/2, L-3)
- kernels/attention.py: generalized decode path to support seq_len>1 (for
  the multi-token verify pass) via query_pos = start_pos + arange(seq_len)

Status: functionally correct (lossless greedy output verified against HF).
Acceptance length is below the paper's reported ~3.3 — under investigation
(likely a hidden-state position/timing issue in the draft-verify loop seeding).
…yers

Switch aux capture to post-layer (output of tap layers 2/12/21) based on
HF validation showing the drafter predicts correctly with HF's hidden
states at hs[3]/hs[13]/hs[22] (output of layers 2/12/21).

Note: acceptance length remains low (~1.0) due to numerical divergence
between nkipy's Neuron-compiled target and the HF CPU reference the
drafter was trained against. The drafter kernel is mathematically correct
(validated against independent torch reference) and correctly predicts
the target when fed exact HF hidden states. The gap is an
implementation-coupling issue inherent to EAGLE-style speculation.
Key findings from the P-EAGLE paper (Figure 2, Figure 3, Section 3):

1. The drafter maintains its own KV cache across the full context
   (prompt + all accepted tokens). At each draft step, K positions
   attend to the FULL accumulated cache.

2. The attention mask is GROUP-CAUSAL: all K positions see the full
   cache (group 0), but within the K positions the NTP (group 1)
   and MTP (group 2+) positions use cross-depth causality — MTP
   positions cannot attend to positions at the same or later depth.

3. The NTP pair is (emb(t_n), hidden_after_processing_t_{n-1}),
   predicting t_{n+1}. The hidden is one step behind the embedding.

This commit adds:
- drafter_cpu.py: CPU reference drafter with full KV cache and
  standard causal attention (working infrastructure, mask needs
  the group-causal refinement for MTP positions)
- Fixes hidden state capture to post-layer (output of tap layers)
- Adds peagle_aux_layers config method

Status: KV cache infrastructure correct, still needs the group-causal
mask refinement for the MTP positions within the K-wide draft window.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant