contrib: Gemma-4-26B-A4B-it port (MoE, TP=8, BF16)#172
Open
xniwangaws wants to merge 2 commits into
Open
Conversation
Port of google/gemma-4-26B-A4B-it (~25.2B/3.8B active MoE, 30 layers, 8 active / 128 total + 1 shared expert via parallel dense MLP). Validated on trn2.48xlarge (SDK 2.29): Stage 1 (DISABLE_MOE): compile 2.2 min, load 20.9s, NEFF 17 MB Stage 2 (MoE on): compile 19.7 min, load 29.1s, NEFF 297 MB Stage 3 (inference): TTFT 309.5 ms, TPOT 8.79 ms, 114 tok/s Borrows NKI flash attn (d=256 SWA + d=512), KV cache manager, softcap LM head, scaled embedding, and Gemma4 RMSNorm flavours from PR aws-neuron#106 (gemma-4-31B-IT). 26B-A4B-specific: MoE block with parallel dense MLP at decoder layer (HF source lines 1429-1441), dual-input MoE forward (router sees raw residual, experts see post_feedforward_layernorm_2(residual)), 128 experts with top-k=8 + per-expert-scale routing.
11/12 first-16-token match at 100% vs HF CPU bf16 (Gemma4ForConditionalGeneration,
transformers 5.10.0.dev0) using processor.apply_chat_template (enable_thinking=False
and =True) and processor.parse_response. The lone 87.5% (capital/thinking-on/sample)
is sampling-RNG divergence; greedy at the same setting is 100%. enable_thinking=True
exercises the multi-channel response path through MoE routing -- backends produce
identical greedy tokens and identical parse_response {role,thinking,content} dicts.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Contrib port of
google/gemma-4-26B-A4B-it— a text-only Gemma 4 sibling with a sparse-MoE FFN (~25.2B total / ~3.8B active per token, 30 layers, 8 active / 128 total + 1 shared dense expert). Companion to PR #106 (31B-IT, dense): same Gemma 4 attention/RoPE/softcap/scaled-embed stack, but with a parallel dense-MLP + MoE branch at every decoder layer.What's in this PR
contrib/models/gemma-4-26b-a4b-it/src/modeling_gemma4_neuron.py— text-only port: Gemma4 attention (heterogeneous SWA d=256 / global d=512), parallel-MLP+MoE decoder layer, FP32 router withper_expert_scale, softcapped LM head, KV cache manager with per-layer cache size, HF→Neuron state-dict converter (handles MoE weight stacking, K=V tying for global layers, q-norm pre-scaling).src/nki_flash_attn_d256_swa.py,src/nki_flash_attn_large_d.py— verbatim from PR Add Gemma 4 31B IT contrib model #106.src/ndxi_patch.py— verbatim from PR Add Gemma 4 31B IT contrib model #106 (one-line relative-import tweak so the file is self-contained insidesrc/).src/configuration_gemma4_neuron.py— split-out HF config dataclass for static parsing.scripts/smoke_compile.py,scripts/smoke_inference.py— Stage 1/2/3 runners.test/integration/test_model.py— pytest-friendly thin wrapper.README.md+DIFF_FROM_PR106.md— review companions; the diff doc enumerates exactly what is identical / adapted / new vs. PR Add Gemma 4 31B IT contrib model #106.Validation
trn2.48xlarge, SDK 2.29 (
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/, torch 2.9.1, NxDI 0.10.0), TP=8, BF16, seq_len=256, LNC=2."Hello, my name is"Known limitations / follow-ups
smoke_inference.pyfalls back to the rawtokenizersRust backend.seq_len > 256compile is a follow-up.layer_to_cache_size_mappingand theget_last_kv_windowpatch).Test plan