"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from flax import nnx
from maxdiffusion import common_types
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward

Array = common_types.Array
DType = common_types.DType


class _BasicTransformerBlock1D(nnx.Module):
  """Single pre-norm 1D transformer block.

  Structure: RMSNorm -> self-attention -> residual add, then RMSNorm ->
  feed-forward -> residual add. Both norms are scale-free (use_scale=False),
  so the block carries no learned normalization gains.
  """

  def __init__(
      self,
      dim: int,
      heads: int,
      dim_head: int,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """Builds the attention, feed-forward, and pre-norm submodules.

    Args:
      dim: Channel dimension of the token embeddings.
      heads: Number of attention heads.
      dim_head: Per-head dimension.
      rope_type: Rotary-embedding layout forwarded to LTX2Attention.
      attention_kernel: Attention implementation name (e.g. "flash", "dot_product").
      mesh: Optional device mesh forwarded to the attention layer for sharding.
      rngs: NNX RNG container used for parameter initialization.
    """
    self.attn1 = LTX2Attention(
        query_dim=dim,
        heads=heads,
        dim_head=dim_head,
        rope_type=rope_type,
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
    # Scale-free RMSNorm computed in float32 for numerical stability.
    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    """Applies the block; input and output are [batch, seq, dim]."""
    # 1. Norm -> Attention -> residual
    normed = self.norm1(hidden_states)
    attn_output = self.attn1(normed, attention_mask=attention_mask, rotary_emb=rotary_emb)
    hidden_states = hidden_states + attn_output

    # 2. Norm -> FeedForward -> residual
    normed = self.norm2(hidden_states)
    ff_output = self.ff(normed)
    hidden_states = hidden_states + ff_output

    return hidden_states


class Embeddings1DConnector(nnx.Module):
  """
  Applies 1D transformer processing with Thinking Tokens (Learnable Registers).
  Uses nnx.scan for efficient JAX-idiomatic layer execution.
  """

  def __init__(
      self,
      input_dim: int,
      heads: int = 30,
      head_dim: int = 128,
      layers: int = 2,
      theta: float = 10000.0,
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """Creates a stacked (vmapped) set of transformer blocks plus register tokens.

    Args:
      input_dim: Channel dimension of the incoming embeddings.
      heads: Attention heads per block.
      head_dim: Per-head dimension.
      layers: Number of stacked transformer blocks (scanned at call time).
      theta: Base frequency for the 1D rotary embedding.
      num_learnable_registers: Number of learnable "thinking token" vectors;
        0 disables register substitution entirely.
      rope_type: Rotary layout forwarded to each block's attention.
      attention_kernel: Attention implementation forwarded to each block.
      mesh: Optional device mesh for sharded attention.
      rngs: NNX RNG container for parameter initialization.
    """
    self.dim = input_dim
    self.theta = theta
    self.num_learnable_registers = num_learnable_registers
    self.num_layers = layers

    # 1. Initialize Stacked Layers using vmap
    # This creates a single module where parameters have an extra leading dimension [layers, ...]
    # We need to ensure rngs are split for each layer
    @nnx.split_rngs(splits=layers)
    @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
    def create_block(rngs):
      return _BasicTransformerBlock1D(
          dim=input_dim,
          heads=heads,
          dim_head=head_dim,
          rope_type=rope_type,
          attention_kernel=attention_kernel,
          mesh=mesh,
          rngs=rngs,
      )

    # Call the vmapped constructor
    self.stacked_blocks = create_block(rngs)

    # 2. Thinking Tokens: learnable registers initialized uniformly in [-1, 1),
    # stored in bfloat16.
    if num_learnable_registers > 0:
      key = rngs.params()
      self.learnable_registers = nnx.Param(
          jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
      )

    self.final_norm = nnx.RMSNorm(
        self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
    )

  def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
    """Overwrites padded positions with tiled learnable registers.

    The registers [R, D] are tiled along the sequence axis to length T (hence T
    must be divisible by R), and each position where the mask is 0 is replaced
    by the register at the same sequence index. The returned mask is all-ones,
    since after substitution every position carries meaningful content.

    Args:
      hidden_states: Token embeddings [B, T, D].
      attention_mask: Validity mask, [B, T] or [B, T, 1]; >0.5 means "keep input".

    Returns:
      (new_hidden_states, all-ones mask of the original mask's shape).
    """
    b, t, d = hidden_states.shape
    if t % self.num_learnable_registers != 0:
      raise ValueError(f"Sequence length {t} must be divisible by {self.num_learnable_registers}")

    num_duplications = t // self.num_learnable_registers
    # `[...]` reads the underlying value out of the nnx.Param.
    registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
    registers = jnp.expand_dims(registers, 0)  # [1, T, D], broadcast over batch

    if attention_mask.ndim == 2:
      mask = attention_mask[:, :, None]
    else:
      mask = attention_mask

    # Keep valid tokens; substitute registers where the mask marks padding.
    output = jnp.where(mask > 0.5, hidden_states, registers)
    new_mask = jnp.ones_like(attention_mask)
    return output, new_mask

  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
    """Computes (cos, sin) tables for a 1D rotary embedding, each [1, seq_len, dim].

    Frequencies follow the standard RoPE schedule theta**(-2i/dim); cos/sin are
    repeated element-wise (x2) to match the interleaved rotary layout.

    NOTE(review): the `dtype` parameter is currently unused — the tables are
    returned in float32 regardless. Presumably intentional for precision, but
    confirm against the attention layer's expectations.
    """
    t = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    emb = jnp.outer(t, freqs)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    return cos[None, ...], sin[None, ...]

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
  ) -> Array:
    """Runs register substitution, the scanned transformer stack, and a final norm.

    Args:
      hidden_states: Token embeddings [B, T, D].
      attention_mask: Optional validity mask [B, T]; required for register
        substitution (skipped when None).

    Returns:
      Processed embeddings [B, T, D].
    """
    # 1. Thinking Tokens: replace padded slots, after which the mask is all-ones.
    if self.num_learnable_registers > 0 and attention_mask is not None:
      hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

    # 2. RoPE
    seq_len = hidden_states.shape[1]
    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)

    # 3. Transformer Blocks (Scan)

    # Scan function signature: (carry, x) -> (carry, y)
    def block_scan_fn(carry, block_module):
      hidden_states = carry
      # block_module is a sliced view of the vmapped module
      hidden_states = block_module(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
      return hidden_states, None

    # Execute scan
    hidden_states, _ = nnx.scan(
        block_scan_fn,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
        out_axes=(nnx.Carry, 0),
    )(hidden_states, self.stacked_blocks)

    # 4. Final Norm
    hidden_states = self.final_norm(hidden_states)

    return hidden_states
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+""" + +import unittest +import jax.numpy as jnp +import numpy as np +from flax import nnx +from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector + + +class Embeddings1DConnectorTest(unittest.TestCase): + + def setUp(self): + self.rng = nnx.Rngs(0) + self.B = 2 + self.T = 16 # Must be divisible by num_learnable_registers if we want tiling to work simply + self.D = 64 # inner_dim + + # Test config + self.num_learnable_registers = 8 + self.heads = 4 + self.head_dim = 16 + + # input dim = heads * head_dim = 64 + + def test_thinking_tokens_replacement(self): + connector = Embeddings1DConnector( + input_dim=self.D, + heads=self.heads, + head_dim=self.head_dim, + layers=1, + num_learnable_registers=self.num_learnable_registers, + mesh=None, + rngs=self.rng, + ) + + # Create input [B, T, D] + hidden_states = jnp.zeros((self.B, self.T, self.D)) + + # Create mask [B, T] + # Batch 0: First 4 valid, rest padding + # Batch 1: First 8 valid, rest padding + mask = np.zeros((self.B, self.T), dtype=np.int32) + mask[0, :4] = 1 + mask[1, :8] = 1 + + # Explicitly run replacement method + output, new_mask = connector._replace_padded_with_learnable_registers(hidden_states, jnp.array(mask)) + + # 1. Check Mask Reset + self.assertTrue(jnp.all(new_mask == 1.0), "New mask should be all 1s") + + # 2. Check Valid Tokens (should be 0 as input was 0) + # Batch 0, 0-3 + valid_b0 = output[0, :4, :] + self.assertTrue(jnp.all(valid_b0 == 0.0), "Valid tokens should remain unchanged") + + # 3. Check Thinking Tokens (Padding area) + # Batch 0, 4-15 + thinking_b0 = output[0, 4:, :] + + # The learnable registers should be tiled. + # Registers shape: [8, 64] + # T=16, so it's tiled 2 times -> [16, 64] + # We need to verify that padding positions contain values from registers + + # Get expected registers values + registers_val = connector.learnable_registers[...] 
# [8, 64] + tiled_regs = jnp.tile(registers_val, (2, 1)) # [16, 64] + + expected_padding = tiled_regs[4:, :] # corresponding slice + + np.testing.assert_allclose( + thinking_b0, expected_padding, err_msg="Padding should be replaced by corresponding register values" + ) + print("\n[PASS] Thinking Tokens Replacement Logic Verified.") + + def test_forward_shape_and_run(self): + connector = Embeddings1DConnector( + input_dim=self.D, + heads=self.heads, + head_dim=self.head_dim, + layers=2, + num_learnable_registers=self.num_learnable_registers, + attention_kernel="dot_product", # Use dot_product for testing on CPU + mesh=None, + rngs=self.rng, + ) + + hidden_states = jnp.array(np.random.randn(self.B, self.T, self.D)) + mask = jnp.ones((self.B, self.T)) # All valid + + output = connector(hidden_states, mask) + + self.assertEqual(output.shape, (self.B, self.T, self.D)) + self.assertFalse(jnp.isnan(output).any(), "Output should not contain NaNs") + print("\n[PASS] Embeddings1DConnector Forward Pass Verified.") + + +if __name__ == "__main__": + unittest.main()