diff --git a/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py new file mode 100644 index 00000000..ec74dcda --- /dev/null +++ b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py @@ -0,0 +1,73 @@ +import torch +import numpy as np +from transformers import AutoTokenizer, AutoModel + + +class HFGemma3TextEncoder: + """ + A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states. + This module forces execution on CPU to avoid OOM or XLA collisions when used alongside + JAX/MaxDiffusion on TPUs. + """ + + def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192): + self.model_id = model_id + self.max_length = max_length + # Initialize the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + # Load the model directly to CPU in bfloat16 to save memory + print(f"Loading {model_id} onto CPU. This may take a few moments...") + self.model = AutoModel.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + device_map="cpu", # Force CPU to avoid TPU memory contention with MaxDiffusion + ) + self.model.eval() # Set to evaluation mode + + def encode(self, text: str | list[str]) -> np.ndarray: + """ + Tokenizes the input text, passes it through the HF Gemma 3 model, + and extracts ALL hidden states. + + Args: + text: A single string or a list of strings to encode. + + Returns: + A numpy array representing the flattened, stacked hidden states + compatible with GemmaFeaturesExtractorProjLinear. + Shape: (batch_size, sequence_length, 49 * 3840) + """ + # 1. Tokenize input text + inputs = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") + + # Ensure inputs are on the same device as the model (CPU) + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + # 2. 
Forward pass to get hidden states + # output_hidden_states=True is the key to retrieving all 49 layers + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + + # 3. Extract and stack hidden states + # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840) + all_hidden_states = outputs.hidden_states + + # Stack them along a new dimension (dim=0 or dim=-2) + # We want to format it so it's easy to flatten. + # Stacked shape: (49, batch, seq_len, 3840) + stacked_states = torch.stack(all_hidden_states, dim=0) + + # Transpose to: (batch, seq_len, 49, 3840) + transposed_states = stacked_states.permute(1, 2, 0, 3) + + # Flatten the last two dimensions to match the Feature Extractor's expectation + # Shape becomes: (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160) + batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape + flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim) + + # 4. Convert PyTorch Tensor to NumPy Array + # JAX/Flax can seamlessly accept and convert numpy arrays to JAX Arrays + numpy_hidden_states = flattened_states.cpu().float().numpy() + + return numpy_hidden_states diff --git a/src/maxdiffusion/tests/test_hf_gemma3_encoder.py b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py new file mode 100644 index 00000000..13659231 --- /dev/null +++ b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py @@ -0,0 +1,35 @@ +import pytest +import numpy as np + +from maxdiffusion.models.ltx2.text_encoders.hf_gemma3_encoder import HFGemma3TextEncoder + + +class TestHFGemma3TextEncoder: + """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder.""" + + @pytest.fixture(scope="class") + def encoder(self): + """Initialize the encoder. We use a small max_length to save memory and time.""" + print("Initializing HFGemma3TextEncoder on CPU...") + # Note: Depending on your system memory, loading 12B on CPU might take ~25GB RAM. 
class TestHFGemma3TextEncoder:
    """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder."""

    @pytest.fixture(scope="class")
    def encoder(self):
        """Initialize the encoder once per class.

        A small max_length keeps the forward pass cheap. Note: loading the 12B
        checkpoint on CPU may need ~25GB of host RAM — ensure the test node has
        enough memory.
        """
        print("Initializing HFGemma3TextEncoder on CPU...")
        return HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=16)

    def test_encode_output_shape(self, encoder):
        """Verify that encode() returns the correctly flattened numpy array."""
        prompt = "A test prompt for HF Gemma 3"

        print("Running encode forward pass on CPU...")
        output_array = encoder.encode(prompt)

        # Must be a plain numpy array so JAX can consume it directly.
        assert isinstance(output_array, np.ndarray), "Output must be a numpy array for JAX integration."

        # Derive the expected feature width from the loaded model's config instead of
        # hard-coding 49 * 3840, so the test survives a model-variant swap.
        # Multimodal Gemma 3 checkpoints nest the LM config under `text_config`;
        # fall back to the top-level config for text-only variants.
        config = getattr(encoder.model.config, "text_config", encoder.model.config)
        num_states = config.num_hidden_layers + 1  # +1 for the embedding-layer output
        expected_shape = (1, 16, num_states * config.hidden_size)
        assert output_array.shape == expected_shape, f"Expected shape {expected_shape}, got {output_array.shape}"

        print(f"✅ Output successfully shaped for GemmaFeaturesExtractorProjLinear: {output_array.shape}")