Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions gemma/gm/ckpts/_paths_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for checkpoint paths."""

import re

from gemma.gm.ckpts import _paths


class TestCheckpointPath:
"""Tests for CheckpointPath enum."""

def test_is_str_enum(self):
"""All enum values should be strings."""
for member in _paths.CheckpointPath:
assert isinstance(member.value, str)

def test_all_paths_start_with_gs(self):
"""All checkpoint paths should be GCS paths."""
for member in _paths.CheckpointPath:
assert member.value.startswith('gs://'), (
f'{member.name} does not start with gs://: {member.value}'
)

def test_all_paths_contain_checkpoints(self):
"""All paths should contain the 'checkpoints' directory."""
for member in _paths.CheckpointPath:
assert '/checkpoints/' in member.value, (
f'{member.name} missing /checkpoints/ in path: {member.value}'
)

def test_path_format_matches_name(self):
"""Path basename should be a lowercased version consistent with the name.

E.g. GEMMA3_4B_IT -> gemma3-4b-it
"""
for member in _paths.CheckpointPath:
basename = member.value.rsplit('/', 1)[-1]
# Convert name: GEMMA3_4B_IT -> gemma3-4b-it
expected = member.name.lower().replace('_', '-')
# Handle gemma3n enum names: GEMMA3N_E2B_IT -> gemma3n-e2b-it
assert basename == expected, (
f'{member.name}: expected basename {expected!r}, got {basename!r}'
)

def test_naming_convention_version_size_variant(self):
"""All names should follow VERSION_SIZE_VARIANT pattern."""
pattern = re.compile(
r'^GEMMA[23N]*_' # Version: GEMMA2, GEMMA3, GEMMA3N
r'(270M|1B|2B|4B|9B|12B|27B|E2B|E4B)_' # Size
r'(PT|IT)$' # Variant: Pre-trained or Instruction-Tuned
)
for member in _paths.CheckpointPath:
assert pattern.match(member.name), (
f'{member.name} does not match VERSION_SIZE_VARIANT pattern'
)

def test_pt_and_it_pairs_exist(self):
"""Every model size should have both PT and IT variants."""
names = [m.name for m in _paths.CheckpointPath]
pt_names = {n.replace('_PT', '') for n in names if n.endswith('_PT')}
it_names = {n.replace('_IT', '') for n in names if n.endswith('_IT')}
assert pt_names == it_names, (
f'PT/IT mismatch. PT-only: {pt_names - it_names}, '
f'IT-only: {it_names - pt_names}'
)

def test_gemma2_models_exist(self):
"""Gemma 2 should have 2B, 9B, 27B sizes."""
expected = {'GEMMA2_2B_PT', 'GEMMA2_9B_PT', 'GEMMA2_27B_PT',
'GEMMA2_2B_IT', 'GEMMA2_9B_IT', 'GEMMA2_27B_IT'}
names = {m.name for m in _paths.CheckpointPath}
assert expected.issubset(names)

def test_gemma3_models_exist(self):
"""Gemma 3 should have 270M, 1B, 4B, 12B, 27B sizes."""
expected = {'GEMMA3_270M_PT', 'GEMMA3_1B_PT', 'GEMMA3_4B_PT',
'GEMMA3_12B_PT', 'GEMMA3_27B_PT',
'GEMMA3_270M_IT', 'GEMMA3_1B_IT', 'GEMMA3_4B_IT',
'GEMMA3_12B_IT', 'GEMMA3_27B_IT'}
names = {m.name for m in _paths.CheckpointPath}
assert expected.issubset(names)

def test_gemma3n_models_exist(self):
"""Gemma 3N should have E2B and E4B sizes."""
expected = {'GEMMA3N_E2B_PT', 'GEMMA3N_E4B_PT',
'GEMMA3N_E2B_IT', 'GEMMA3N_E4B_IT'}
names = {m.name for m in _paths.CheckpointPath}
assert expected.issubset(names)

def test_str_enum_lookup(self):
"""Should be able to use string value to look up the enum."""
path = 'gs://gemma-data/checkpoints/gemma3-4b-it'
member = _paths.CheckpointPath(path)
assert member is _paths.CheckpointPath.GEMMA3_4B_IT

def test_no_duplicate_paths(self):
"""All paths should be unique."""
values = [m.value for m in _paths.CheckpointPath]
assert len(values) == len(set(values)), 'Duplicate checkpoint paths found'

def test_total_count(self):
"""Verify the expected total number of checkpoints."""
# Gemma2: 3 sizes * 2 variants = 6
# Gemma3: 5 sizes * 2 variants = 10
# Gemma3N: 2 sizes * 2 variants = 4
# Total = 20
assert len(_paths.CheckpointPath) == 20
42 changes: 42 additions & 0 deletions gemma/gm/data/_tasks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for data tasks helpers."""

from gemma.gm.data import _tasks


class TestDecodeBytes:
"""Tests for _decode_bytes helper."""

def test_bytes_decoded_to_str(self):
assert _tasks._decode_bytes(b'hello world') == 'hello world'

def test_str_passthrough(self):
assert _tasks._decode_bytes('already a string') == 'already a string'

def test_utf8_bytes(self):
text = 'café'
assert _tasks._decode_bytes(text.encode('utf-8')) == text

def test_empty_bytes(self):
assert _tasks._decode_bytes(b'') == ''

def test_empty_string(self):
assert _tasks._decode_bytes('') == ''

def test_non_string_passthrough(self):
"""Non-bytes, non-string values should pass through unchanged."""
assert _tasks._decode_bytes(42) == 42
assert _tasks._decode_bytes(None) is None
167 changes: 167 additions & 0 deletions gemma/gm/losses/_dpo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for DPO loss."""

from gemma.gm.losses import _dpo
import jax
import jax.numpy as jnp
import numpy as np


def _make_one_hot_logits(targets, vocab_size):
"""Creates logits that are high for the target tokens."""
one_hot = jax.nn.one_hot(targets, vocab_size)
# Scale so softmax is strongly peaked at the target.
return one_hot * 10.0 - 5.0


class TestGetLogprobsForTarget:
"""Tests for _get_logprobs_for_target helper."""

def test_shape(self):
batch, n, seq_len, vocab = 2, 2, 4, 8
logits = jnp.ones((batch, n, seq_len, vocab))
targets = jnp.zeros((batch, n, seq_len), dtype=jnp.int32)
mask = jnp.ones((batch, n, seq_len), dtype=jnp.bool_)

result = _dpo._get_logprobs_for_target(
logits=logits, targets=targets, sequence_mask=mask,
)
assert result.shape == (batch, n)

def test_masked_tokens_ignored(self):
"""Masked positions should not contribute to log-probs."""
vocab = 4
# B=1, N=1, L=3
logits = jnp.zeros((1, 1, 3, vocab))
targets = jnp.array([[[0, 1, 2]]], dtype=jnp.int32)

mask_all = jnp.ones((1, 1, 3), dtype=jnp.bool_)
mask_partial = jnp.array([[[True, False, False]]])

result_all = _dpo._get_logprobs_for_target(
logits=logits, targets=targets, sequence_mask=mask_all,
)
result_partial = _dpo._get_logprobs_for_target(
logits=logits, targets=targets, sequence_mask=mask_partial,
)
# Partial mask should give a less-negative (higher) value since fewer
# tokens contribute.
assert float(result_partial[0, 0]) > float(result_all[0, 0])

def test_perfect_logits_give_near_zero_logprob(self):
"""When logits strongly favor the target, log-prob should be near 0."""
vocab = 4
targets = jnp.array([[[0, 1]]], dtype=jnp.int32) # B=1, N=1, L=2
logits = _make_one_hot_logits(targets, vocab)
mask = jnp.ones((1, 1, 2), dtype=jnp.bool_)

result = _dpo._get_logprobs_for_target(
logits=logits, targets=targets, sequence_mask=mask,
)
# Sum of log-probs should be close to 0 (each token ~ log(1) = 0).
np.testing.assert_allclose(float(result[0, 0]), 0.0, atol=0.02)


class TestDpoLoss:
"""Tests for DpoLoss.get_values."""

def test_output_shape(self):
batch, n, seq_len, vocab = 2, 2, 4, 8
tokens = jnp.zeros((batch, n, seq_len), dtype=jnp.int32)
mask = jnp.ones((batch, n, seq_len), dtype=jnp.bool_)
logits = jnp.ones((batch, n, seq_len, vocab))

loss = _dpo.DpoLoss()
result = loss.get_values(
tokens=tokens,
sequence_mask=mask,
policy_logits=logits,
anchor_logits=logits,
)
assert result.shape == (batch, 1)

def test_zero_when_policy_equals_anchor(self):
"""When policy == anchor, diff_logprob is 0 for both chosen/rejected.

po_delta = 0, so loss = -(log_sigmoid(0)*(1-ls) + log_sigmoid(0)*ls)
= -log_sigmoid(0) = -log(0.5) = log(2).
"""
batch, n, seq_len, vocab = 1, 2, 3, 4
rng = jax.random.PRNGKey(42)
logits = jax.random.normal(rng, (batch, n, seq_len, vocab))
tokens = jnp.zeros((batch, n, seq_len), dtype=jnp.int32)
mask = jnp.ones((batch, n, seq_len), dtype=jnp.bool_)

loss = _dpo.DpoLoss(tau=0.1, label_smoothing=0.0)
result = loss.get_values(
tokens=tokens,
sequence_mask=mask,
policy_logits=logits,
anchor_logits=logits,
)
# po_delta = 0 => loss = -log_sigmoid(0) = log(2)
np.testing.assert_allclose(
float(result[0, 0]), np.log(2), atol=1e-5
)

def test_loss_is_non_negative(self):
"""DPO loss should always be non-negative."""
batch, n, seq_len, vocab = 3, 2, 5, 8
rng = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(rng)
policy_logits = jax.random.normal(k1, (batch, n, seq_len, vocab))
anchor_logits = jax.random.normal(k2, (batch, n, seq_len, vocab))
tokens = jax.random.randint(k1, (batch, n, seq_len), 0, vocab)
mask = jnp.ones((batch, n, seq_len), dtype=jnp.bool_)

loss = _dpo.DpoLoss(tau=0.1, label_smoothing=0.0)
result = loss.get_values(
tokens=tokens,
sequence_mask=mask,
policy_logits=policy_logits,
anchor_logits=anchor_logits,
)
assert jnp.all(result >= 0.0)

def test_label_smoothing_effect(self):
"""Label smoothing should change the loss value."""
batch, n, seq_len, vocab = 1, 2, 3, 4
rng = jax.random.PRNGKey(7)
k1, k2 = jax.random.split(rng)
policy_logits = jax.random.normal(k1, (batch, n, seq_len, vocab))
anchor_logits = jax.random.normal(k2, (batch, n, seq_len, vocab))
tokens = jnp.zeros((batch, n, seq_len), dtype=jnp.int32)
mask = jnp.ones((batch, n, seq_len), dtype=jnp.bool_)

loss_no_smooth = _dpo.DpoLoss(tau=0.1, label_smoothing=0.0)
loss_smooth = _dpo.DpoLoss(tau=0.1, label_smoothing=0.5)

result_no_smooth = loss_no_smooth.get_values(
tokens=tokens,
sequence_mask=mask,
policy_logits=policy_logits,
anchor_logits=anchor_logits,
)
result_smooth = loss_smooth.get_values(
tokens=tokens,
sequence_mask=mask,
policy_logits=policy_logits,
anchor_logits=anchor_logits,
)
# With label_smoothing=0.5, loss = -log_sigmoid(x)*0.5 - log_sigmoid(-x)*0.5
# = -0.5*(log_sigmoid(x) + log_sigmoid(-x))
# which differs from label_smoothing=0.0 unless po_delta=0.
assert not jnp.allclose(result_no_smooth, result_smooth)
Loading