diff --git a/tests/generate/utils_test.py b/tests/generate/utils_test.py index 682eafb3b..2d19d3e86 100644 --- a/tests/generate/utils_test.py +++ b/tests/generate/utils_test.py @@ -1090,5 +1090,100 @@ def test_transfer_state_directly_implicit_layers_container(self): jnp.array(200.0), ) + def test_transfer_state_directly_with_dtype_casting(self): + """Tests that transfer_state_directly correctly casts dtypes (e.g., f32 to bf16).""" + # Source state in float32 + src_state = nnx.Dict( + decoder=nnx.Dict( + layer0=nnx.Dict( + weight=nnx.Param(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)) + ), + # Scanned layers in float32 + layers=nnx.Dict( + mlp=nnx.Dict( + weight=nnx.Param(jnp.array([[10.0, 11.0], [20.0, 21.0]], dtype=jnp.float32)) + ) + ) + ) + ) + + # Destination state in bfloat16 + dst_state = nnx.Dict( + decoder=nnx.Dict( + layer0=nnx.Dict( + weight=nnx.Param(jnp.zeros((3,), dtype=jnp.bfloat16)) + ), + # Unrolled layers in bfloat16 + layers_0=nnx.Dict( + mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((2,), dtype=jnp.bfloat16))) + ), + layers_1=nnx.Dict( + mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((2,), dtype=jnp.bfloat16))) + ) + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard) + + # Verify direct mapping cast + self.assertEqual(dst_state['decoder']['layer0']['weight'].dtype, jnp.bfloat16) + np.testing.assert_allclose( + dst_state['decoder']['layer0']['weight'][...], + jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16), + atol=1e-2 + ) + + # Verify scanned layer mapping cast + self.assertEqual(dst_state['decoder']['layers_0']['mlp']['weight'].dtype, jnp.bfloat16) + np.testing.assert_allclose( + dst_state['decoder']['layers_0']['mlp']['weight'][...], + jnp.array([10.0, 11.0], dtype=jnp.bfloat16), + atol=1e-2 + ) + self.assertEqual(dst_state['decoder']['layers_1']['mlp']['weight'].dtype, jnp.bfloat16) + np.testing.assert_allclose( + dst_state['decoder']['layers_1']['mlp']['weight'][...], + jnp.array([20.0, 21.0], dtype=jnp.bfloat16), + atol=1e-2 + ) + + def test_transfer_state_directly_scanned_layers_casting(self): + """Tests transfer from scanned layers container with dtype casting.""" + # Source has scanned layers in float32 + src_state = nnx.Dict( + layers=nnx.Dict( + mlp=nnx.Dict( + weight=nnx.Param(jnp.array([100.0, 200.0], dtype=jnp.float32)) + ) + ) + ) + + # Destination has unrolled layers_X in bfloat16 + dst_state = nnx.Dict( + layers=nnx.Dict( + layers_0=nnx.Dict(mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((), dtype=jnp.bfloat16)))), + layers_1=nnx.Dict(mlp=nnx.Dict(weight=nnx.Param(jnp.zeros((), dtype=jnp.bfloat16)))), + ) + ) + + mock_reshard = lambda source, target: source + utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard) + + # Verify casting and slicing for implicit layers + self.assertEqual(dst_state['layers']['layers_0']['mlp']['weight'].dtype, jnp.bfloat16) + np.testing.assert_allclose( + dst_state['layers']['layers_0']['mlp']['weight'][...], + jnp.array(100.0, dtype=jnp.bfloat16), + atol=1e-2 + ) + self.assertEqual(dst_state['layers']['layers_1']['mlp']['weight'].dtype, jnp.bfloat16) + np.testing.assert_allclose( + dst_state['layers']['layers_1']['mlp']['weight'][...], + jnp.array(200.0, dtype=jnp.bfloat16), + atol=1e-2 + ) + + if __name__ == "__main__": absltest.main() diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 7b9ce6578..cbadbcf39 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -915,7 +915,11 @@ def intersect_trees( for key_tuple, tgt_val in tgt_flat.items(): # Try Direct Match if key_tuple in src_flat: - filtered_src_flat[key_tuple] = src_flat[key_tuple] + src_val = src_flat[key_tuple] + if hasattr(src_val, 'dtype'): + src_val = _apply_dtype_cast(src_val, tgt_val.dtype, str(key_tuple)) + + filtered_src_flat[key_tuple] = src_val filtered_tgt_flat[key_tuple] = tgt_val continue @@ -954,9 +958,16 @@ def intersect_trees( if found_candidate: src_val = src_flat[found_candidate] - filtered_src_flat[key_tuple] = _slice_scanned_param( + # Slice the scanned parameter + sliced_val = _slice_scanned_param( src_val, tgt_val, layer_idx, str(key_tuple) ) + if hasattr(sliced_val, 'dtype'): + sliced_val = _apply_dtype_cast( + sliced_val, tgt_val.dtype, str(key_tuple) + ) + + filtered_src_flat[key_tuple] = sliced_val filtered_tgt_flat[key_tuple] = tgt_val continue