contrib: Gemma-4-26B-A4B-it port (MoE, TP=8, BF16) #172

Open

xniwangaws wants to merge 2 commits into aws-neuron:main from xniwangaws:contrib_gemma_4_26b_a4b_it

@xniwangaws

Summary

Contrib port of google/gemma-4-26B-A4B-it — a text-only Gemma 4 sibling with a sparse-MoE FFN (~25.2B total / ~3.8B active per token, 30 layers, 8 active / 128 total + 1 shared dense expert). Companion to PR #106 (31B-IT, dense): same Gemma 4 attention/RoPE/softcap/scaled-embed stack, but with a parallel dense-MLP + MoE branch at every decoder layer.

What's in this PR

  • contrib/models/gemma-4-26b-a4b-it/
    • src/modeling_gemma4_neuron.py — text-only port: Gemma4 attention (heterogeneous SWA d=256 / global d=512), parallel-MLP+MoE decoder layer, FP32 router with per_expert_scale, softcapped LM head, KV cache manager with per-layer cache size, HF→Neuron state-dict converter (handles MoE weight stacking, K=V tying for global layers, q-norm pre-scaling).
    • src/nki_flash_attn_d256_swa.py, src/nki_flash_attn_large_d.py — verbatim from PR #106.
    • src/ndxi_patch.py — taken from PR #106, unchanged apart from a one-line relative-import tweak so the file is self-contained inside src/.
    • src/configuration_gemma4_neuron.py — split-out HF config dataclass for static parsing.
    • scripts/smoke_compile.py, scripts/smoke_inference.py — Stage 1/2/3 runners.
    • test/integration/test_model.py — pytest-friendly thin wrapper.
    • README.md + DIFF_FROM_PR106.md — review companions; the diff doc enumerates exactly what is identical / adapted / new vs. PR #106.
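
The "softcapped LM head" listed above refers to squashing the output logits into a bounded range. A minimal sketch of the usual tanh-based softcap follows. The function name `softcap` and the cap value used in the example call are illustrative assumptions, not taken from this PR, and the actual port applies the operation to tensors inside src/modeling_gemma4_neuron.py rather than to Python lists.

```python
import math


def softcap(logits, cap):
    """Squash each logit into (-cap, cap) with cap * tanh(x / cap).

    Small logits pass through almost unchanged; large ones saturate
    smoothly at +/- cap instead of growing without bound.
    """
    if cap <= 0:
        raise ValueError("cap must be positive")
    return [cap * math.tanh(x / cap) for x in logits]


if __name__ == "__main__":
    # 30.0 is an illustrative cap, not a value confirmed by this PR.
    print(softcap([-200.0, -1.0, 0.0, 1.0, 200.0], cap=30.0))
```

Because tanh is strictly increasing, the ranking of logits is preserved wherever the values have not saturated in floating point, so greedy decoding picks the same token with or without the cap in the common case.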

Validation

trn2.48xlarge, SDK 2.29 (/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/, torch 2.9.1, NxDI 0.10.0), TP=8, BF16, seq_len=256, LNC=2.

| Stage | Configuration | Result | Details |
| --- | --- | --- | --- |
| 1 | DISABLE_MOE compile + load | PASS | compile 2.2 min, load 20.9 s, NEFF 17 MB |
| 2 | MoE on compile + load | PASS | compile 19.7 min, load 29.1 s, NEFF 297 MB |
| 3 | Inference smoke | PASS | TTFT 309.5 ms, TPOT 8.79 ms, 114 tok/s; greedy continuation of "Hello, my name is" |
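
The Stage 3 throughput figure follows directly from the per-token latency. The sketch below shows that arithmetic. The helper names are hypothetical, and the end-to-end estimate assumes that TTFT already covers the first generated token, which the PR does not state.

```python
def tokens_per_second(tpot_ms):
    """Steady-state decode throughput implied by time-per-output-token."""
    return 1000.0 / tpot_ms


def end_to_end_ms(ttft_ms, tpot_ms, new_tokens):
    """Estimated latency for `new_tokens` tokens.

    Assumption: the first token costs TTFT and every later token costs
    one TPOT.
    """
    if new_tokens < 1:
        raise ValueError("new_tokens must be at least 1")
    return ttft_ms + (new_tokens - 1) * tpot_ms


if __name__ == "__main__":
    print(round(tokens_per_second(8.79)))         # reported TPOT from Stage 3
    print(round(end_to_end_ms(309.5, 8.79, 8), 2))  # 8 generated tokens
```

With the reported TPOT of 8.79 ms this gives about 113.8 tok/s, which rounds to the 114 tok/s in the results.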

Known limitations / follow-ups

  • Smoke-test only: no token-match accuracy validation vs. HF reference yet (greedy + no chat template ⇒ Stage 3 output is the expected base-style repetition).
  • AutoTokenizer trips on Gemma 4's special-tokens list-vs-dict shape; smoke_inference.py falls back to the raw tokenizers Rust backend.
  • Multimodal towers (vision, audio) not ported — text-only. Use PR #106 / PR #109 (Add Gemma-4 support) for VLM.
  • seq_len > 256 compile is a follow-up.
  • NxDI ≥ 0.10 required (per-layer layer_to_cache_size_mapping and the get_last_kv_window patch).
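
A version gate for the NxDI ≥ 0.10 requirement is easy to get wrong with plain string comparison, because "0.10" sorts before "0.9" lexicographically. The following is a minimal sketch of a numeric comparison. The helper names are hypothetical, and the installed version string is left as an input because the PR does not say how the scripts read it.

```python
import re


def version_tuple(version):
    """Parse the leading dotted-integer part, e.g. '0.10.0' -> (0, 10, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if `installed` is numerically at least `minimum`."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '0.10' and '0.10.0' compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    print(meets_minimum("0.10.0", "0.10"))  # version reported in Validation
    print(meets_minimum("0.9.2", "0.10"))
```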

Test plan

  • Stage 1 (dense path): compile + load on trn2.48xlarge
  • Stage 2 (MoE on): compile + load on trn2.48xlarge
  • Stage 3: inference smoke (8 generated tokens, throughput ≥ 100 tok/s)
  • Token-match accuracy vs HF CPU reference (follow-up)
  • Longer seq_len compile (1024 / 2048) (follow-up)

Port of google/gemma-4-26B-A4B-it (~25.2B/3.8B active MoE, 30 layers,
8 active / 128 total + 1 shared expert via parallel dense MLP).

Validated on trn2.48xlarge (SDK 2.29):
  Stage 1 (DISABLE_MOE): compile 2.2 min, load 20.9s, NEFF 17 MB
  Stage 2 (MoE on):      compile 19.7 min, load 29.1s, NEFF 297 MB
  Stage 3 (inference):   TTFT 309.5 ms, TPOT 8.79 ms, 114 tok/s

Borrows NKI flash attn (d=256 SWA + d=512), KV cache manager, softcap
LM head, scaled embedding, and Gemma4 RMSNorm flavours from PR aws-neuron#106
(gemma-4-31B-IT).

26B-A4B-specific: MoE block with parallel dense MLP at decoder layer
(HF source lines 1429-1441), dual-input MoE forward (router sees raw
residual, experts see post_feedforward_layernorm_2(residual)), 128
experts with top-k=8 + per-expert-scale routing.
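
The dual-input MoE forward described above can be sketched in plain Python. This is an illustration under stated assumptions, not the code in this PR: all function names are hypothetical, the routing recipe (softmax, top-k, renormalize, then per-expert scale) is an assumed ordering, and summing the dense and MoE branches is likewise an assumption.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(router_logits, per_expert_scale, k):
    """Pick k experts and their mixing weights.

    Assumed recipe (not confirmed by the PR): softmax over all experts,
    keep the top k, renormalize them to sum to 1, then apply the
    per-expert scale.
    """
    probs = softmax(router_logits)
    top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:k]
    norm = sum(probs[e] for e in top)
    return [(e, probs[e] / norm * per_expert_scale[e]) for e in top]


def moe_forward(residual, expert_input, router_weight, per_expert_scale, experts, k):
    """Dual-input MoE: the router reads `residual`, experts read `expert_input`."""
    # Router logits come from the raw residual. The port keeps the router
    # in FP32; ordinary Python floats are used here.
    logits = [sum(w * x for w, x in zip(row, residual)) for row in router_weight]
    out = [0.0] * len(expert_input)
    for e, weight in route_top_k(logits, per_expert_scale, k):
        y = experts[e](expert_input)  # experts see the normalized residual
        out = [o + weight * v for o, v in zip(out, y)]
    return out


def decoder_ffn(dense_out, moe_out):
    """Combine the parallel dense-MLP and MoE branches (summation is assumed)."""
    return [d + m for d, m in zip(dense_out, moe_out)]
```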

Accuracy vs. HF CPU bf16 (Gemma4ForConditionalGeneration, transformers
5.10.0.dev0): 11 of 12 cases match the reference on 100% of the first 16
tokens. Prompts are built with processor.apply_chat_template
(enable_thinking=False and =True) and outputs are parsed with
processor.parse_response. The one case at 87.5% (capital/thinking-on/sample)
is sampling-RNG divergence; greedy decoding at the same setting matches 100%.
enable_thinking=True exercises the multi-channel response path through MoE
routing: both backends produce identical greedy tokens and identical
parse_response {role,thinking,content} dicts.
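
A first-16-token match rate of the kind quoted above can be computed as a position-wise comparison of token IDs. The sketch below assumes that definition, which the commit message does not spell out, and the function name is hypothetical. Under this definition 87.5% corresponds to 14 of 16 positions agreeing.

```python
def prefix_match_rate(reference, candidate, n=16):
    """Fraction of the first n positions where the token IDs agree.

    Positions missing from either sequence count as mismatches, so a
    short output cannot score 100%.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    matches = sum(r == c for r, c in zip(reference[:n], candidate[:n]))
    return matches / n


if __name__ == "__main__":
    ref = list(range(100, 116))        # 16 made-up token IDs
    cand = list(ref)
    cand[14], cand[15] = -1, -2        # diverge on the last two positions
    print(prefix_match_rate(ref, cand))
```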