SkipLoRA checkpoint restore fails for Gemma4_E2B — tree structure mismatch #611

@stharrold

Description

gm.ckpts.SkipLoRA fails to restore pre-trained weights into a LoRA-wrapped Gemma4_E2B model. The checkpoint's flat ArrayImpl params don't match the LoRA wrapper's dict structure at MLP linear layers.

This blocks LoRA fine-tuning for all Gemma 4 models using the native gemma + kauldron stack.

Reproduction

from gemma import gm
from kauldron import kd
import optax

model = gm.nn.LoRA(
    rank=8,
    model=gm.nn.Gemma4_E2B(tokens="batch.input"),
)

init_transform = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA4_E2B_IT,
    )
)

tokenizer = gm.text.Gemma3Tokenizer()

trainer = kd.train.Trainer(
    seed=42,
    workdir="/tmp/test_workdir",
    train_ds=kd.data.py.Tfds(
        name="mtnt/en-fr",
        split="train",
        shuffle=True,
        num_epochs=None,
        batch_size=1,
        num_workers=1,
        transforms=[
            gm.data.Seq2SeqTask(
                in_prompt="src",
                in_response="dst",
                out_input="input",
                out_target="target",
                out_target_mask="loss_mask",
                tokenizer=tokenizer,
                max_length=512,
                truncate=True,
            ),
        ],
    ),
    model=model,
    init_transform=init_transform,
    num_train_steps=10,
    train_losses={
        "xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=kd.optim.partial_updates(
        optax.adafactor(learning_rate=5e-5),
        mask=kd.optim.select("lora"),
    ),
)

trainer.train()  # Fails here

Error

Training failed: User-provided restore item and on-disk value metadata
tree structures do not match:

layer_9.mlp.linear:
- Target: <class 'dict'>
- Source: <class 'jaxlib._jax.ArrayImpl'>

If this mismatch is intentional, pass `partial_restore=True` to only
restore parameters found in `item`.

Note: partial_restore=True is not a valid parameter for LoadCheckpoint.
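The mismatch can be reproduced structurally with plain pytrees, independent of orbax. A minimal sketch, assuming the wrapper splits each linear layer into a base weight plus adapter params (the `w`/`lora_a`/`lora_b` key names are illustrative, not the actual gemma internals; the `layer_9.mlp.linear` path is taken from the error above):

```python
import jax.numpy as jnp
from jax import tree_util

# On-disk checkpoint: the MLP linear weight is a single flat array.
source = {"layer_9": {"mlp": {"linear": jnp.zeros((4, 4))}}}

# LoRA-wrapped model: the same path is a dict splitting base and adapter
# weights (key names here are hypothetical, for illustration only).
target = {"layer_9": {"mlp": {"linear": {
    "w": jnp.zeros((4, 4)),
    "lora_a": jnp.zeros((4, 8)),
    "lora_b": jnp.zeros((8, 4)),
}}}}

# Orbax compares tree structures before restoring; these do not match:
# at layer_9.mlp.linear the target is a dict, the source a bare array.
print(tree_util.tree_structure(source) == tree_util.tree_structure(target))
# → False
```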

Environment

  • TPU: ct6e-standard-1t, topology 1x1 (Vertex AI, europe-west4)
  • Python: 3.12
  • JAX: 0.9.2
  • gemma: git main @ b2dc124 (2026-04-03)
  • dialog: git main @ 44599a4 (fixes Error with model = gm.nn.Gemma4_E2B() #606)
  • kauldron: latest from PyPI
  • orbax-checkpoint: 0.11.33

What works

  • gm.nn.Gemma4_E2B() model creation — OK
  • gm.nn.LoRA(rank=8, model=...) wrapper — OK
  • gm.ckpts.load_params(CheckpointPath.GEMMA4_E2B_IT) base load — OK (19.1 GiB, 65s)
  • gm.text.ChatSampler(model=..., params=...).chat(...) generation — OK
  • Kauldron trainer config construction — OK

Only SkipLoRA + LoadCheckpoint fails during trainer.train() init.

Variants tested (all fail with same error)

| Variant | Result |
| --- | --- |
| `Gemma4_E2B(tokens="batch.input")` | Tree mismatch |
| `Gemma4_E2B(tokens="batch.input", text_only=True)` | Tree mismatch |
| `Gemma4_E2B()` (no tokens arg) | `Cannot resolve required keys: ['tokens']` |
| `LoadCheckpoint(..., partial_restore=True)` | `unexpected keyword argument` |

Expected behavior

SkipLoRA should map the checkpoint's flat linear params to the LoRA wrapper's split structure (base + lora_a + lora_b), the same way it works for Gemma3_4B in examples/lora.py.
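A sketch of the remapping I would expect SkipLoRA to perform, under the assumption that LoRA splits each linear into a base weight plus `lora_a`/`lora_b` adapters (all key names below are hypothetical, not gemma's actual param names):

```python
import jax.numpy as jnp

def remap(ckpt, target):
    """Nest flat checkpoint arrays into the LoRA wrapper's structure.

    Wherever the target holds a dict but the checkpoint holds a bare
    array, the array is assumed to be the base weight ("w" is a
    hypothetical key) and the adapter params are kept from the target,
    i.e. left at their fresh initialization.
    """
    if isinstance(target, dict):
        if not isinstance(ckpt, dict):
            # Flat array meets LoRA split: keep adapters, insert base.
            return {k: (ckpt if k == "w" else v) for k, v in target.items()}
        return {k: remap(ckpt[k], v) for k, v in target.items()}
    return ckpt

ckpt = {"linear": jnp.ones((4, 4))}
target = {"linear": {"w": jnp.zeros((4, 4)),
                     "lora_a": jnp.zeros((4, 8)),
                     "lora_b": jnp.zeros((8, 4))}}
restored = remap(ckpt, target)
# restored["linear"]["w"] now holds the checkpoint weight;
# lora_a / lora_b stay at their initialization.
```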

Notes

examples/lora.py in this repo demonstrates this exact pattern with Gemma3_4B and works. The failure is specific to Gemma 4, likely because Gemma 4's MLP layers have a different structure (possibly related to MoE or the new architecture) that SkipLoRA does not yet handle.
