Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/generate/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,5 +1090,100 @@ def test_transfer_state_directly_implicit_layers_container(self):
jnp.array(200.0),
)

def test_transfer_state_directly_with_dtype_casting(self):
"""Tests that transfer_state_directly correctly casts dtypes (e.g., f32 to bf16)."""
# Source state in float32
src_state = nnx.Dict(
decoder=nnx.Dict(
layer0=nnx.Dict(
weight=nnx.Param(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
),
# Scanned layers in float32
layers=nnx.Dict(
mlp=nnx.Dict(
weight=nnx.Param(jnp.array([[10.0, 11.0], [20.0, 21.0]], dtype=jnp.float32))
)
)
)
)

# Destination state in bfloat16
dst_state = nnx.Dict(
decoder=nnx.Dict(
layer0=nnx.Dict(
weight=nnx.Param(jnp.zeros((3,), dtype=jnp.bfloat16))
),
# Unrolled layers in bfloat16
layers_0=nnx.Dict(
mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((2,), dtype=jnp.bfloat16)))
),
layers_1=nnx.Dict(
mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((2,), dtype=jnp.bfloat16)))
)
)
)

mock_reshard = lambda source, target: source
utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard)

# Verify direct mapping cast
self.assertEqual(dst_state['decoder']['layer0']['weight'].dtype, jnp.bfloat16)
np.testing.assert_allclose(
dst_state['decoder']['layer0']['weight'][...],
jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16),
atol=1e-2
)

# Verify scanned layer mapping cast
self.assertEqual(dst_state['decoder']['layers_0']['mlp']['weight'].dtype, jnp.bfloat16)
np.testing.assert_allclose(
dst_state['decoder']['layers_0']['mlp']['weight'][...],
jnp.array([10.0, 11.0], dtype=jnp.bfloat16),
atol=1e-2
)
self.assertEqual(dst_state['decoder']['layers_1']['mlp']['weight'].dtype, jnp.bfloat16)
np.testing.assert_allclose(
dst_state['decoder']['layers_1']['mlp']['weight'][...],
jnp.array([20.0, 21.0], dtype=jnp.bfloat16),
atol=1e-2
)

def test_transfer_state_directly_scanned_layers_casting(self):
"""Tests transfer from scanned layers container with dtype casting."""
# Source has scanned layers in float32
src_state = nnx.Dict(
layers=nnx.Dict(
mlp=nnx.Dict(
weight=nnx.Param(jnp.array([100.0, 200.0], dtype=jnp.float32))
)
)
)

# Destination has unrolled layers_X in bfloat16
dst_state = nnx.Dict(
layers=nnx.Dict(
layers_0=nnx.Dict(mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((), dtype=jnp.bfloat16)))),
layers_1=nnx.Dict(mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((), dtype=jnp.bfloat16)))),
)
)

mock_reshard = lambda source, target: source
utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard)

# Verify casting and slicing for implicit layers
self.assertEqual(dst_state['layers']['layers_0']['mlp']['weight'].dtype, jnp.bfloat16)
np.testing.assert_allclose(
dst_state['layers']['layers_0']['mlp']['weight'][...],
jnp.array(100.0, dtype=jnp.bfloat16),
atol=1e-2
)
self.assertEqual(dst_state['layers']['layers_1']['mlp']['weight'].dtype, jnp.bfloat16)
np.testing.assert_allclose(
dst_state['layers']['layers_1']['mlp']['weight'][...],
jnp.array(200.0, dtype=jnp.bfloat16),
atol=1e-2
)


if __name__ == "__main__":
absltest.main()
15 changes: 13 additions & 2 deletions tunix/generate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,11 @@ def intersect_trees(
for key_tuple, tgt_val in tgt_flat.items():
# Try Direct Match
if key_tuple in src_flat:
filtered_src_flat[key_tuple] = src_flat[key_tuple]
src_val = src_flat[key_tuple]
if hasattr(src_val, 'dtype'):
src_val = _apply_dtype_cast(src_val, tgt_val.dtype, str(key_tuple))

filtered_src_flat[key_tuple] = src_val
filtered_tgt_flat[key_tuple] = tgt_val
continue

Expand Down Expand Up @@ -954,9 +958,16 @@ def intersect_trees(

if found_candidate:
src_val = src_flat[found_candidate]
filtered_src_flat[key_tuple] = _slice_scanned_param(
# Slice the scanned parameter
sliced_val = _slice_scanned_param(
src_val, tgt_val, layer_idx, str(key_tuple)
)
if hasattr(sliced_val, 'dtype'):
sliced_val = _apply_dtype_cast(
sliced_val, tgt_val.dtype, str(key_tuple)
)

filtered_src_flat[key_tuple] = sliced_val
filtered_tgt_flat[key_tuple] = tgt_val
continue

Expand Down