Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/research.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Here is an example to create a text-only, 12 layers transformer:
class MyTinyTransformer(gm.nn.Transformer):
config: gm.nn.config.TransformerConfig = gm.nn.config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144, # Vocab size, matching the tokenizer
vocab_size=262144, # Vocab size, matching the tokenizer
embed_dim=896,
hidden_dim=4 * 896,
num_heads=4,
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class QueryPreAttentionNormalisation(enum.Enum):
class TransformerConfig:
"""Configuration for the gemma transformer."""

num_embed: int # TODO(epot): Rename to `vocab_size` for consistency.
vocab_size: int
embed_dim: int
hidden_dim: int
num_heads: int
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Gemma3_500m(_transformer.Transformer): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144,
vocab_size=262144,
embed_dim=896,
hidden_dim=4 * 896,
num_heads=4,
Expand Down
16 changes: 8 additions & 8 deletions gemma/gm/nn/_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Gemma2_2B(_transformer.Transformer): # pylint: disable=invalid-name
"""Gemma2 transformer architecture."""

config: _config.TransformerConfig = _config.TransformerConfig(
num_embed=256128,
vocab_size=256128,
embed_dim=2304,
hidden_dim=9216,
num_heads=8,
Expand Down Expand Up @@ -79,7 +79,7 @@ class Gemma2_9B(_transformer.Transformer): # pylint: disable=invalid-name
"""Gemma2 transformer architecture."""

config: _config.TransformerConfig = _config.TransformerConfig(
num_embed=256128,
vocab_size=256128,
embed_dim=3584,
hidden_dim=14336,
num_heads=16,
Expand Down Expand Up @@ -109,7 +109,7 @@ class Gemma2_27B(_transformer.Transformer): # pylint: disable=invalid-name
"""Gemma2 transformer architecture."""

config: _config.TransformerConfig = _config.TransformerConfig(
num_embed=256128,
vocab_size=256128,
embed_dim=4608,
hidden_dim=36864,
num_heads=32,
Expand Down Expand Up @@ -140,7 +140,7 @@ class Gemma3_270M(_transformer.Transformer): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144,
vocab_size=262144,
embed_dim=640,
hidden_dim=2048,
num_heads=4,
Expand Down Expand Up @@ -171,7 +171,7 @@ class Gemma3_1B(_transformer.Transformer): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144,
vocab_size=262144,
embed_dim=1152,
hidden_dim=6 * 1152,
num_heads=4,
Expand Down Expand Up @@ -223,7 +223,7 @@ class Gemma3_4B(_Gemma3Base): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262_144,
vocab_size=262_144,
embed_dim=2560,
hidden_dim=2560 * 8 // 2,
num_heads=8,
Expand Down Expand Up @@ -256,7 +256,7 @@ class Gemma3_12B(_Gemma3Base): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144,
vocab_size=262144,
embed_dim=30 * 128,
hidden_dim=8 * 30 * 128 // 2,
num_heads=16,
Expand Down Expand Up @@ -288,7 +288,7 @@ class Gemma3_27B(_Gemma3Base): # pylint: disable=invalid-name

config: _config.TransformerConfig = _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262144,
vocab_size=262144,
embed_dim=5376,
hidden_dim=5376 * 8 // 2,
num_heads=32,
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __post_init__(self):

def setup(self):
self.embedder = _modules.Embedder(
vocab_size=self.config.num_embed,
vocab_size=self.config.vocab_size,
embed_dim=self.config.embed_dim,
vision_proj_dim=self.config.vision_encoder.siglip_encoder.width
if self.config.vision_encoder
Expand Down
4 changes: 2 additions & 2 deletions gemma/gm/nn/_transformer_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ class TransformerConfig(Protocol):

Attributes:
input_config: Configuration for the model's input.
num_embed: Vocabulary size.
vocab_size: Vocabulary size.
"""

input_config: _types.InputConfig
num_embed: int
vocab_size: int

def init_cache(
self,
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/_transformer_like_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _init_and_apply(

def _get_config() -> gm.nn.config.TransformerConfig:
return gm.nn.config.TransformerConfig(
num_embed=13,
vocab_size=13,
embed_dim=32,
hidden_dim=128,
num_heads=2,
Expand Down
8 changes: 4 additions & 4 deletions gemma/gm/nn/_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_transformer(model_cls: type[gm.nn.Transformer]):
model = model_cls() # pylint: disable=missing-kwoa # pytype: disable=missing-parameter
tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
out, _ = _get_output(model, tokens=tokens)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_images():
Expand All @@ -57,7 +57,7 @@ def test_images():
images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8)
out, _ = _get_output(model, tokens=tokens, images=images)

assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_text_only():
Expand All @@ -72,12 +72,12 @@ def test_text_only():

out, params = _get_output(model, tokens=tokens)
assert 'vision_encoder' not in params # Vision params not loaded
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_last_only():
model = gm.nn.Gemma3_4B(return_last_only=True)
tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
out, params = _get_output(model, tokens=tokens)
assert 'vision_encoder' in params # Vision by default
assert out.logits.shape == (BATCH_SIZE, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, model.config.vocab_size)
2 changes: 1 addition & 1 deletion gemma/gm/nn/gemma3n/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class QueryPreAttentionNormalisation(enum.Enum):
class TransformerConfig:
"""Configuration for the gemma transformer."""

num_embed: int # TODO(epot): Rename to `vocab_size` for consistency.
vocab_size: int
embed_dim: int
hidden_dim: int
num_heads: int
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/gemma3n/_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _get_gemma3n_config(
) -> _config.TransformerConfig:
return _config.TransformerConfig(
final_logit_softcap=None,
num_embed=262_144,
vocab_size=262_144,
embed_dim=2048,
hidden_dim=hidden_dim,
num_heads=8,
Expand Down
2 changes: 1 addition & 1 deletion gemma/gm/nn/gemma3n/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __post_init__(self):

def setup(self):
self.embedder = _modules.Embedder(
vocab_size=self.config.num_embed,
vocab_size=self.config.vocab_size,
embed_dim=self.config.embed_dim,
vision_proj_dim=self.config.vision_encoder.siglip_encoder.width
if self.config.vision_encoder
Expand Down
8 changes: 4 additions & 4 deletions gemma/gm/nn/gemma3n/_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_transformer(model_cls: type[gt.Gemma3nTransformer]):
model = model_cls() # pylint: disable=missing-kwoa # pytype: disable=missing-parameter
tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
out, _ = _get_output(model, tokens=tokens)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_images():
Expand All @@ -58,7 +58,7 @@ def test_images():
images = jnp.ones((BATCH_SIZE, NUM_IMAGES, 64, 64, 3), dtype=jnp.uint8)
out, _ = _get_output(model, tokens=tokens, images=images)

assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_text_only():
Expand All @@ -73,12 +73,12 @@ def test_text_only():

out, params = _get_output(model, tokens=tokens)
assert 'vision_encoder' not in params # Vision params not loaded
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, SEQ_LEN, model.config.vocab_size)


def test_last_only():
model = gemma3n_models.Gemma3n_E4B(return_last_only=True) # pytype: disable=missing-parameter
tokens = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
out, params = _get_output(model, tokens=tokens)
assert 'vision_encoder' in params # Vision by default
assert out.logits.shape == (BATCH_SIZE, model.config.num_embed)
assert out.logits.shape == (BATCH_SIZE, model.config.vocab_size)
2 changes: 1 addition & 1 deletion gemma/gm/testing/_dummy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DummyGemma(_transformer.Transformer): # pylint: disable=invalid-name
"""Dummy transformer architecture, for testing."""

config: config_lib.TransformerConfig = config_lib.TransformerConfig(
num_embed=13, # Vocab size matching `gm.testing.DummyTokenizer()`
vocab_size=13, # Vocab size matching `gm.testing.DummyTokenizer()`
embed_dim=32,
hidden_dim=128,
num_heads=2,
Expand Down
2 changes: 1 addition & 1 deletion gemma/research/t5gemma/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_config(
"""Simplify the config creation."""
return TransformerConfig(
num_layers=num_layers,
num_embed=256128,
vocab_size=256128,
embed_dim=embed_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
Expand Down
4 changes: 2 additions & 2 deletions gemma/research/t5gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class TransformerConfig:
"""Configuration for the gemma transformer."""

num_layers: int
num_embed: int
vocab_size: int
embed_dim: int
hidden_dim: int
num_heads: int
Expand Down Expand Up @@ -585,7 +585,7 @@ class Transformer(nn.Module):

def setup(self):
self.embedder = Embedder(
vocab_size=self.config.num_embed,
vocab_size=self.config.vocab_size,
embed_dim=self.config.embed_dim
)

Expand Down