From 45bb0257532537b1177d618a5f3d579abd8a324e Mon Sep 17 00:00:00 2001 From: James Huang Date: Wed, 25 Feb 2026 23:38:52 +0000 Subject: [PATCH 1/2] feat: Add HF-based Gemma 3 text encoder for LTX-2 CPU feature extraction --- .../ltx2/text_encoders/hf_gemma3_encoder.py | 77 +++++++++++++++++++ .../tests/test_hf_gemma3_encoder.py | 34 ++++++++ 2 files changed, 111 insertions(+) create mode 100644 src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py create mode 100644 src/maxdiffusion/tests/test_hf_gemma3_encoder.py diff --git a/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py new file mode 100644 index 00000000..94480837 --- /dev/null +++ b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py @@ -0,0 +1,77 @@ +import torch +import numpy as np +from transformers import AutoTokenizer, AutoModel + +class HFGemma3TextEncoder: + """ + A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states. + This module forces execution on CPU to avoid OOM or XLA collisions when used alongside + JAX/MaxDiffusion on TPUs. + """ + def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192): + self.model_id = model_id + self.max_length = max_length + # Initialize the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + # Load the model directly to CPU in bfloat16 to save memory + print(f"Loading {model_id} onto CPU. This may take a few moments...") + self.model = AutoModel.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + device_map="cpu", # Force CPU to avoid TPU memory contention with MaxDiffusion + ) + self.model.eval() # Set to evaluation mode + + def encode(self, text: str | list[str]) -> np.ndarray: + """ + Tokenizes the input text, passes it through the HF Gemma 3 model, + and extracts ALL hidden states. + + Args: + text: A single string or a list of strings to encode. + + Returns: + A numpy array representing the flattened, stacked hidden states + compatible with GemmaFeaturesExtractorProjLinear. + Shape: (batch_size, sequence_length, 49 * 3840) + """ + # 1. Tokenize input text + inputs = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.max_length, + return_tensors="pt" + ) + + # Ensure inputs are on the same device as the model (CPU) + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + # 2. Forward pass to get hidden states + # output_hidden_states=True is the key to retrieving all 49 layers + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + + # 3. Extract and stack hidden states + # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840) + all_hidden_states = outputs.hidden_states + + # Stack them along a new dimension (dim=0 or dim=-2) + # We want to format it so it's easy to flatten. + # Stacked shape: (49, batch, seq_len, 3840) + stacked_states = torch.stack(all_hidden_states, dim=0) + + # Transpose to: (batch, seq_len, 49, 3840) + transposed_states = stacked_states.permute(1, 2, 0, 3) + + # Flatten the last two dimensions to match the Feature Extractor's expectation + # Shape becomes: (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160) + batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape + flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim) + + # 4. Convert PyTorch Tensor to NumPy Array + # JAX/Flax can seamlessly accept and convert numpy arrays to JAX Arrays + numpy_hidden_states = flattened_states.cpu().float().numpy() + + return numpy_hidden_states diff --git a/src/maxdiffusion/tests/test_hf_gemma3_encoder.py b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py new file mode 100644 index 00000000..68e2d60c --- /dev/null +++ b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py @@ -0,0 +1,34 @@ +import pytest +import numpy as np + +from maxdiffusion.models.ltx2.text_encoders.hf_gemma3_encoder import HFGemma3TextEncoder + +class TestHFGemma3TextEncoder: + """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder.""" + + @pytest.fixture(scope="class") + def encoder(self): + """Initialize the encoder. We use a small max_length to save memory and time.""" + print("Initializing HFGemma3TextEncoder on CPU...") + # Note: Depending on your system memory, loading 12B on CPU might take ~25GB RAM. + # Ensure the test node has enough CPU RAM. + encoder = HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=16) + return encoder + + def test_encode_output_shape(self, encoder): + """Verify that the encode method returns the correctly flattened numpy array.""" + prompt = "A test prompt for HF Gemma 3" + + # Run encode + print("Running encode forward pass on CPU...") + output_array = encoder.encode(prompt) + + # Verify it's a numpy array + assert isinstance(output_array, np.ndarray), "Output must be a numpy array for JAX integration." + + # Verify shape + # Expected: (batch_size, sequence_length, 49 * 3840) -> (1, 16, 188160) + expected_shape = (1, 16, 49 * 3840) + assert output_array.shape == expected_shape, f"Expected shape {expected_shape}, got {output_array.shape}" + + print(f"✅ Output successfully shaped for GemmaFeaturesExtractorProjLinear: {output_array.shape}") From b0817a98d7deb498e7b8996438e98f59fb5bcd95 Mon Sep 17 00:00:00 2001 From: James Huang Date: Thu, 26 Feb 2026 01:33:03 +0000 Subject: [PATCH 2/2] ci fix Signed-off-by: James Huang --- .../ltx2/text_encoders/hf_gemma3_encoder.py | 126 +++++++++--------- .../tests/test_hf_gemma3_encoder.py | 57 ++++---- 2 files changed, 90 insertions(+), 93 deletions(-) diff --git a/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py index 94480837..ec74dcda 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py @@ -2,76 +2,72 @@ import numpy as np from transformers import AutoTokenizer, AutoModel + class HFGemma3TextEncoder: + """ + A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states. + This module forces execution on CPU to avoid OOM or XLA collisions when used alongside + JAX/MaxDiffusion on TPUs. + """ + + def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192): + self.model_id = model_id + self.max_length = max_length + # Initialize the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + # Load the model directly to CPU in bfloat16 to save memory + print(f"Loading {model_id} onto CPU. This may take a few moments...") + self.model = AutoModel.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + device_map="cpu", # Force CPU to avoid TPU memory contention with MaxDiffusion + ) + self.model.eval() # Set to evaluation mode + + def encode(self, text: str | list[str]) -> np.ndarray: """ - A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states. - This module forces execution on CPU to avoid OOM or XLA collisions when used alongside - JAX/MaxDiffusion on TPUs. + Tokenizes the input text, passes it through the HF Gemma 3 model, + and extracts ALL hidden states. + + Args: + text: A single string or a list of strings to encode. + + Returns: + A numpy array representing the flattened, stacked hidden states + compatible with GemmaFeaturesExtractorProjLinear. + Shape: (batch_size, sequence_length, 49 * 3840) """ - def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192): - self.model_id = model_id - self.max_length = max_length - # Initialize the tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - - # Load the model directly to CPU in bfloat16 to save memory - print(f"Loading {model_id} onto CPU. This may take a few moments...") - self.model = AutoModel.from_pretrained( - self.model_id, - torch_dtype=torch.bfloat16, - device_map="cpu", # Force CPU to avoid TPU memory contention with MaxDiffusion - ) - self.model.eval() # Set to evaluation mode + # 1. Tokenize input text + inputs = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") + + # Ensure inputs are on the same device as the model (CPU) + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + # 2. Forward pass to get hidden states + # output_hidden_states=True is the key to retrieving all 49 layers + with torch.no_grad(): + outputs = self.model(**inputs, output_hidden_states=True) + + # 3. Extract and stack hidden states + # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840) + all_hidden_states = outputs.hidden_states - def encode(self, text: str | list[str]) -> np.ndarray: - """ - Tokenizes the input text, passes it through the HF Gemma 3 model, - and extracts ALL hidden states. - - Args: - text: A single string or a list of strings to encode. - - Returns: - A numpy array representing the flattened, stacked hidden states - compatible with GemmaFeaturesExtractorProjLinear. - Shape: (batch_size, sequence_length, 49 * 3840) - """ - # 1. Tokenize input text - inputs = self.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=self.max_length, - return_tensors="pt" - ) - - # Ensure inputs are on the same device as the model (CPU) - inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + # Stack them along a new dimension (dim=0 or dim=-2) + # We want to format it so it's easy to flatten. + # Stacked shape: (49, batch, seq_len, 3840) + stacked_states = torch.stack(all_hidden_states, dim=0) - # 2. Forward pass to get hidden states - # output_hidden_states=True is the key to retrieving all 49 layers - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) + # Transpose to: (batch, seq_len, 49, 3840) + transposed_states = stacked_states.permute(1, 2, 0, 3) - # 3. Extract and stack hidden states - # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840) - all_hidden_states = outputs.hidden_states - - # Stack them along a new dimension (dim=0 or dim=-2) - # We want to format it so it's easy to flatten. - # Stacked shape: (49, batch, seq_len, 3840) - stacked_states = torch.stack(all_hidden_states, dim=0) - - # Transpose to: (batch, seq_len, 49, 3840) - transposed_states = stacked_states.permute(1, 2, 0, 3) - - # Flatten the last two dimensions to match the Feature Extractor's expectation - # Shape becomes: (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160) - batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape - flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim) + # Flatten the last two dimensions to match the Feature Extractor's expectation + # Shape becomes: (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160) + batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape + flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim) - # 4. Convert PyTorch Tensor to NumPy Array - # JAX/Flax can seamlessly accept and convert numpy arrays to JAX Arrays - numpy_hidden_states = flattened_states.cpu().float().numpy() + # 4. Convert PyTorch Tensor to NumPy Array + # JAX/Flax can seamlessly accept and convert numpy arrays to JAX Arrays + numpy_hidden_states = flattened_states.cpu().float().numpy() - return numpy_hidden_states + return numpy_hidden_states diff --git a/src/maxdiffusion/tests/test_hf_gemma3_encoder.py b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py index 68e2d60c..13659231 100644 --- a/src/maxdiffusion/tests/test_hf_gemma3_encoder.py +++ b/src/maxdiffusion/tests/test_hf_gemma3_encoder.py @@ -3,32 +3,33 @@ from maxdiffusion.models.ltx2.text_encoders.hf_gemma3_encoder import HFGemma3TextEncoder + class TestHFGemma3TextEncoder: - """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder.""" - - @pytest.fixture(scope="class") - def encoder(self): - """Initialize the encoder. We use a small max_length to save memory and time.""" - print("Initializing HFGemma3TextEncoder on CPU...") - # Note: Depending on your system memory, loading 12B on CPU might take ~25GB RAM. - # Ensure the test node has enough CPU RAM. - encoder = HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=16) - return encoder - - def test_encode_output_shape(self, encoder): - """Verify that the encode method returns the correctly flattened numpy array.""" - prompt = "A test prompt for HF Gemma 3" - - # Run encode - print("Running encode forward pass on CPU...") - output_array = encoder.encode(prompt) - - # Verify it's a numpy array - assert isinstance(output_array, np.ndarray), "Output must be a numpy array for JAX integration." - - # Verify shape - # Expected: (batch_size, sequence_length, 49 * 3840) -> (1, 16, 188160) - expected_shape = (1, 16, 49 * 3840) - assert output_array.shape == expected_shape, f"Expected shape {expected_shape}, got {output_array.shape}" - - print(f"✅ Output successfully shaped for GemmaFeaturesExtractorProjLinear: {output_array.shape}") + """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder.""" + + @pytest.fixture(scope="class") + def encoder(self): + """Initialize the encoder. We use a small max_length to save memory and time.""" + print("Initializing HFGemma3TextEncoder on CPU...") + # Note: Depending on your system memory, loading 12B on CPU might take ~25GB RAM. + # Ensure the test node has enough CPU RAM. + encoder = HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=16) + return encoder + + def test_encode_output_shape(self, encoder): + """Verify that the encode method returns the correctly flattened numpy array.""" + prompt = "A test prompt for HF Gemma 3" + + # Run encode + print("Running encode forward pass on CPU...") + output_array = encoder.encode(prompt) + + # Verify it's a numpy array + assert isinstance(output_array, np.ndarray), "Output must be a numpy array for JAX integration." + + # Verify shape + # Expected: (batch_size, sequence_length, 49 * 3840) -> (1, 16, 188160) + expected_shape = (1, 16, 49 * 3840) + assert output_array.shape == expected_shape, f"Expected shape {expected_shape}, got {output_array.shape}" + + print(f"✅ Output successfully shaped for GemmaFeaturesExtractorProjLinear: {output_array.shape}")