gm.ckpts.SkipLoRA fails to restore pre-trained weights into a LoRA-wrapped Gemma4_E2B model. The checkpoint's flat ArrayImpl params don't match the LoRA wrapper's dict structure at MLP linear layers.
This blocks LoRA fine-tuning for all Gemma 4 models using the native gemma + kauldron stack.
Error

Training failed: User-provided restore item and on-disk value metadata
tree structures do not match:
layer_9.mlp.linear:
- Target: <class 'dict'>
- Source: <class 'jaxlib._jax.ArrayImpl'>
If this mismatch is intentional, pass `partial_restore=True` to only
restore parameters found in `item`.
Note: partial_restore=True is not a valid parameter for LoadCheckpoint.
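The mismatch is purely structural: at layer_9.mlp.linear the checkpoint holds a bare array where the LoRA-wrapped target expects a dict. A minimal sketch of that structural check in plain Python (not orbax; the base/lora_a/lora_b key names are illustrative, taken from the split described under Expected behavior):

```python
# Toy illustration of the tree-structure check that fails during restore:
# an array leaf in the checkpoint vs. a dict node in the LoRA-wrapped target.

def mismatched_paths(source, target, path=""):
    """Return paths where one tree has a dict node and the other a bare leaf."""
    if isinstance(source, dict) and isinstance(target, dict):
        paths = []
        for k in target:
            paths += mismatched_paths(source.get(k), target[k],
                                      f"{path}.{k}".strip("."))
        return paths
    if isinstance(source, dict) != isinstance(target, dict):
        return [path]
    return []

# Checkpoint side: a bare weight matrix (stand-in for the ArrayImpl leaf).
source = {"layer_9": {"mlp": {"linear": [[0.0] * 4] * 4}}}
# LoRA-wrapped target: a dict at the same path (key names are illustrative).
target = {"layer_9": {"mlp": {"linear": {"base": [[0.0] * 4] * 4,
                                         "lora_a": [[0.0] * 2] * 4,
                                         "lora_b": [[0.0] * 4] * 2}}}}

print(mismatched_paths(source, target))  # ['layer_9.mlp.linear']
```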
What works
- gm.nn.Gemma4_E2B() model creation — OK
- gm.nn.LoRA(rank=8, model=...) wrapper — OK
- gm.ckpts.load_params(CheckpointPath.GEMMA4_E2B_IT) base load — OK (19.1 GiB, 65s)
- gm.text.ChatSampler(model=..., params=...).chat(...) generation — OK
- Kauldron trainer config construction — OK

Only the SkipLoRA + LoadCheckpoint combination fails, during trainer.train() init.
Variants tested (all fail with same error)
- Gemma4_E2B(tokens="batch.input") → Tree mismatch
- Gemma4_E2B(tokens="batch.input", text_only=True) → Tree mismatch
- Gemma4_E2B() (no tokens arg) → Cannot resolve required keys: ['tokens']
- LoadCheckpoint(..., partial_restore=True) → unexpected keyword argument
Expected behavior
SkipLoRA should map the checkpoint's flat linear params to the LoRA wrapper's split structure (base + lora_a + lora_b), the same way it works for Gemma3_4B in examples/lora.py.
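The expected mapping can be sketched as a pytree merge: wherever the target tree has the LoRA split dict and the checkpoint has a bare array, the pretrained array lands in the base slot and the freshly initialized LoRA factors are kept. A minimal sketch in plain Python (the base/lora_a/lora_b key names are assumptions mirroring the split described above, not gemma's actual internals):

```python
def skip_lora_merge(ckpt, target):
    """Merge a flat checkpoint tree into a LoRA-split target tree (sketch)."""
    if isinstance(target, dict) and not isinstance(ckpt, dict):
        # Target expects a LoRA split; checkpoint holds the bare pretrained
        # weights. Put them in the base slot, keep lora_a / lora_b as-is.
        return {k: (ckpt if k == "base" else v) for k, v in target.items()}
    if isinstance(target, dict):
        return {k: skip_lora_merge(ckpt[k], v) for k, v in target.items()}
    return ckpt

# Checkpoint: flat weights at mlp.linear; target: the LoRA wrapper's split.
ckpt = {"layer_9": {"mlp": {"linear": [[1.0] * 4] * 4}}}
target = {"layer_9": {"mlp": {"linear": {"base": [[0.0] * 4] * 4,
                                         "lora_a": [[0.0] * 2] * 4,
                                         "lora_b": [[0.0] * 4] * 2}}}}

restored = skip_lora_merge(ckpt, target)
print(restored["layer_9"]["mlp"]["linear"]["base"][0])  # [1.0, 1.0, 1.0, 1.0]
```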
Notes
The examples/lora.py in this repo demonstrates the pattern with Gemma3_4B — it works. The issue is specific to Gemma 4, likely because Gemma 4's MLP layers have a different structure (possibly related to MoE or the new architecture) that SkipLoRA doesn't handle.
Related issues
- main branch incompatible with published PyPI versions of kauldron and dialog (workaround: install both from git main)
- #606 — dialog.Format AttributeError when creating Gemma4_E2B with model = gm.nn.Gemma4_E2B() (workaround: install dialog from git main)