feat: Gemma4 text generation support (CORE-30) #13376
kijai wants to merge 15 commits into Comfy-Org:master from
Conversation
outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize
> Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough

Adds a Gemma4 multimodal text-encoder module with text/vision/audio tokenizers, model variants, and generation wrappers; extends SD text-encoder detection/loading to support Gemma4. Introduces TORCH_HAS_GQA and threads SDPA kwargs (…)

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
comfy_extras/nodes_textgen.py (1)
35-49: ⚠️ Potential issue | 🔴 Critical

Update `TextGenerateLTX2Prompt` for the new parameters.

This signature change breaks the subclass override below. `TextGenerateLTX2Prompt.execute()` still calls `super().execute(...)` positionally, so its `thinking` argument now lands in `video`. With the default `False`, Gemma4 tokenization will treat that as a provided video input and blow up on `False.movedim(...)`. The inherited `video`/`audio` schema also no longer matches the subclass signature.

Suggested fix
```diff
 class TextGenerateLTX2Prompt(TextGenerate):
 @@
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
         else:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
+        return super().execute(
+            clip,
+            formatted_prompt,
+            max_length,
+            sampling_mode,
+            image=image,
+            video=video,
+            audio=audio,
+            thinking=thinking,
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy_extras/nodes_textgen.py` around lines 35 - 49, The subclass TextGenerateLTX2Prompt.execute signature and its call to super() must be updated to match the base class ordering and new optional inputs so the thinking boolean doesn't end up in the video parameter; change TextGenerateLTX2Prompt.execute to accept (cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) and call super().execute(...) using keyword arguments (e.g., super().execute(clip=clip, prompt=prompt, max_length=max_length, sampling_mode=sampling_mode, image=image, video=video, audio=audio, thinking=thinking)); also update any input/output schema in TextGenerateLTX2Prompt that referenced video/audio to match the inherited io schema so clip.tokenize is passed the correct types.
🧹 Nitpick comments (1)
comfy/ldm/modules/attention.py (1)
227-230: Scale pre-scaling is correct; consider adding `enable_gqa` support for consistency.

The math to compensate for the internal 1/sqrt(dim_head) is correct. However, unlike `attention_basic`, this function doesn't handle `enable_gqa`. If `enable_gqa=True` is passed here, it's silently ignored, which could lead to incorrect results if this code path is exercised by GQA models in the future.

♻️ Optional: Add GQA support for consistency
```diff
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, query.dtype)

     if skip_reshape:
         b, _, _, dim_head = query.shape
     else:
         b, _, dim_head = query.shape
         dim_head //= heads

+    if kwargs.get("enable_gqa", False) and query.shape[-3] != key.shape[-3]:
+        n_rep = query.shape[-3] // key.shape[-3]
+        key = key.repeat_interleave(n_rep, dim=-3)
+        value = value.repeat_interleave(n_rep, dim=-3)
+
     if "scale" in kwargs:
         # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
         query = query * (kwargs["scale"] * dim_head ** 0.5)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/ldm/modules/attention.py` around lines 227 - 230, The pre-scaling block should also respect GQA settings: mirror the logic used in attention_basic by checking kwargs.get("enable_gqa") and, when True, adjust the scale to account for GQA group size before applying it to query (use kwargs.get("gqa_group_size", 1) to compute the additional sqrt factor together with dim_head so the multiplication query = query * (...) properly cancels the internal 1/sqrt(dim_head) for grouped heads); update the block that currently uses kwargs["scale"] and dim_head to incorporate enable_gqa and gqa_group_size.
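The cancellation the comment above relies on can be checked with plain arithmetic. This is an illustrative sketch, not code from the PR; the values are arbitrary:

```python
# Illustrative check of the pre-scaling trick: multiplying the query by
# scale * sqrt(dim_head) cancels SDPA's internal 1/sqrt(dim_head) factor,
# leaving exactly the requested scale.
import math

dim_head = 64
requested_scale = 0.07

q = 1.0
q_prescaled = q * (requested_scale * dim_head ** 0.5)
effective = q_prescaled * (1.0 / math.sqrt(dim_head))  # what SDPA applies internally

print(math.isclose(effective, q * requested_scale))  # True
```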
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/ldm/modules/attention.py`:
- Around line 512-516: The code builds sdpa_extra from kwargs and may pass
enable_gqa to comfy.ops.scaled_dot_product_attention causing a TypeError on
PyTorch <2.5; modify the sdpa_extra construction so it only includes
"enable_gqa" when the runtime PyTorch version check (torch_version_numeric) is
>= (2,5) (otherwise only include "scale"), or alternatively always route through
the existing mismatched-shape fallback logic (the block that handles unequal q/k
shapes) so enable_gqa is never passed to scaled_dot_product_attention on older
PyTorch; update the sdpa_extra creation near the SDP_BATCH_LIMIT branch and
references to kwargs so enable_gqa is gated by torch_version_numeric.
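The version gating described above can be sketched as follows; `build_sdpa_extra` and the pretend version constant are illustrative names, not the actual comfy code:

```python
# Hypothetical sketch of gating the enable_gqa kwarg by PyTorch version.
torch_version_numeric = (2, 4)  # pretend we are on PyTorch 2.4

def build_sdpa_extra(kwargs, torch_version=torch_version_numeric):
    extra = {}
    if "scale" in kwargs:
        extra["scale"] = kwargs["scale"]
    # enable_gqa only exists in scaled_dot_product_attention on PyTorch >= 2.5,
    # so drop it on older runtimes instead of raising a TypeError
    if kwargs.get("enable_gqa") and torch_version >= (2, 5):
        extra["enable_gqa"] = True
    return extra

print(build_sdpa_extra({"scale": 0.1, "enable_gqa": True}))          # {'scale': 0.1}
print(build_sdpa_extra({"enable_gqa": True}, torch_version=(2, 5)))  # {'enable_gqa': True}
```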
In `@comfy/sd.py`:
- Around line 1297-1298: detect_te_model() now returns TEModel.QWEN35_31B for
weight.shape[0]==5120 but TEModel lacks that member and the Qwen35 loader only
supports up to QWEN35_27B; add a new enum member TEModel.QWEN35_31B and extend
the Qwen35 loader branch (the function handling Qwen35 checkpoints / any loader
switch that currently handles QWEN35_7B/QWEN35_13B/QWEN35_27B) to recognize and
correctly load the 31B layout (map the 5120 projection size to the new enum and
implement the corresponding weight slicing/reshaping/parameter assignment logic
consistent with the 27B case). Ensure references to TEModel and the loader
switch (e.g., detect_te_model(), the Qwen35 loading function/class) are updated
so a 31B checkpoint follows the same loading pattern as other Qwen35 sizes.
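A hedged sketch of the suggested enum extension; only the 5120 width comes from the review, the other widths and mappings are placeholders:

```python
# Illustrative sketch of mapping a detected projection width to a new
# TEModel member. Only 5120 -> QWEN35_31B is taken from the review;
# the remaining widths are placeholder values.
from enum import Enum, auto

class TEModel(Enum):
    QWEN35_7B = auto()
    QWEN35_13B = auto()
    QWEN35_27B = auto()
    QWEN35_31B = auto()  # new member for the 5120-wide checkpoint

def detect_te_model(proj_width):
    return {
        3584: TEModel.QWEN35_7B,   # placeholder width
        4096: TEModel.QWEN35_13B,  # placeholder width
        4608: TEModel.QWEN35_27B,  # placeholder width
        5120: TEModel.QWEN35_31B,  # width called out by the review
    }.get(proj_width)

print(detect_te_model(5120))  # TEModel.QWEN35_31B
```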
In `@comfy/text_encoders/gemma4.py`:
- Around line 1074-1080: The tokenizer currently always builds audio_features
from the incoming audio in the block that creates waveform/sample_rate and calls
_extract_mel_spectrogram, which leads to invalid audio embed dicts for variants
that don't support audio (e.g., Gemma4_31B_Config built without
Gemma4AudioMixin). Update the logic in the method that contains this snippet to
first check the model/variant capability (e.g., presence of Gemma4AudioMixin on
the variant or audio_config on the variant instance prepared by _make_variant)
and only process audio when that capability exists; otherwise either raise a
clear error rejecting audio input for audio-less variants or explicitly set
audio_features=[] and skip emitting audio embeds. Reference the audio processing
symbols audio_features, _extract_mel_spectrogram, _make_variant,
Gemma4_31B_Config and Gemma4AudioMixin when making the guard.
- Around line 1006-1008: The Gemma4_Tokenizer.state_dict currently returns an
empty dict so tokenizer_json is dropped; change state_dict in class
Gemma4_Tokenizer to return the tokenizer payload (e.g. include 'tokenizer_json':
self.tokenizer_json) and add a complementary load_state_dict(self, sd) that
restores self.tokenizer_json (and any derived caches) from
sd.get('tokenizer_json') so CLIP.get_sd()/CLIPSave will persist and later
rebuild the tokenizer from tokenizer_json.
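The round-trip this comment asks for can be sketched as below; the class and key names follow the review comment, but the constructor and bodies are illustrative:

```python
# Sketch of persisting tokenizer_json through state_dict round-trips.
class Gemma4_Tokenizer:
    def __init__(self, tokenizer_json=None):
        self.tokenizer_json = tokenizer_json

    def state_dict(self):
        # persist the tokenizer payload instead of returning {}
        return {"tokenizer_json": self.tokenizer_json}

    def load_state_dict(self, sd):
        # restore the payload (derived caches would be rebuilt here too)
        self.tokenizer_json = sd.get("tokenizer_json")

a = Gemma4_Tokenizer({"vocab": {"<pad>": 0}})
b = Gemma4_Tokenizer()
b.load_state_dict(a.state_dict())
print(b.tokenizer_json == a.tokenizer_json)  # True
```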
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 21b2cc8d-4d7a-48de-9afb-8ac74b566014
📒 Files selected for processing (10)
- comfy/ldm/modules/attention.py
- comfy/rmsnorm.py
- comfy/sd.py
- comfy/text_encoders/gemma4.py
- comfy/text_encoders/llama.py
- comfy/text_encoders/lt.py
- comfy/text_encoders/lumina2.py
- comfy/text_encoders/qwen35.py
- comfy/utils.py
- comfy_extras/nodes_textgen.py
💤 Files with no reviewable changes (2)
- comfy/text_encoders/qwen35.py
- comfy/utils.py
Actionable comments posted: 3
♻️ Duplicate comments (1)
comfy/text_encoders/gemma4.py (1)
1078-1084: ⚠️ Potential issue | 🟠 Major

Reject audio on audio-less Gemma4 variants.

`Gemma4_31B_Config` disables audio, and `_make_variant()` skips `Gemma4AudioMixin` for that model, but this tokenizer still emits audio embeds whenever `audio` is passed. On the 31B path those embeds have no consumer and will fail later in embed preprocessing.

Proposed fix
```diff
         # Process audio
         audio_features = []
         if audio is not None:
+            if not getattr(self, "supports_audio", True):
+                raise ValueError("This Gemma4 variant does not support audio inputs.")
             waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
             sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
             mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
             audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))]  # ([1, T, 128], [1, T])
```

```python
# in _make_variant()
class Tokenizer(Gemma4Tokenizer):
    supports_audio = audio
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@comfy/text_encoders/gemma4.py` around lines 1078 - 1084, The tokenizer currently processes audio unconditionally, causing unsupported variants (e.g., Gemma4_31B_Config which omits Gemma4AudioMixin) to produce audio embeds that later fail; update _make_variant to set a per-variant flag and gate audio processing: in _make_variant() define the generated Tokenizer class to include supports_audio = audio (or similar boolean) based on whether Gemma4AudioMixin was included, then modify the audio handling branch in Gemma4Tokenizer._encode/_process (the block that creates audio_features from audio, waveform, sample_rate, _extract_mel_spectrogram) to early-return/ignore or raise if audio is provided but supports_audio is False so audio is rejected for audio-less variants.
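A runnable sketch of the proposed per-variant capability flag; the method bodies here are illustrative stand-ins, the real tokenizer does far more work:

```python
# Per-variant capability flag: the generated Tokenizer class stamps
# supports_audio, and tokenize rejects audio when the flag is False.
class Gemma4Tokenizer:
    supports_audio = True

    def tokenize(self, text, audio=None):
        if audio is not None and not getattr(self, "supports_audio", True):
            raise ValueError("This Gemma4 variant does not support audio inputs.")
        return {"text": text, "audio": audio}

def _make_variant(audio: bool):
    # mirror of the suggestion: stamp the flag onto the generated class
    class Tokenizer(Gemma4Tokenizer):
        supports_audio = audio
    return Tokenizer

Tok31B = _make_variant(audio=False)
print(Tok31B().tokenize("hi")["text"])  # hi
try:
    Tok31B().tokenize("hi", audio=object())
except ValueError as e:
    print("rejected:", e)
```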
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/text_encoders/gemma4.py`:
- Around line 1094-1102: The video subsampling code in gemma4.py uses a
hardcoded default fps=24 which drops most frames; change the logic in the
is_video branch (the block that computes fps = kwargs.get("fps", 24), step and
indices, used by clip.tokenize called from TextGenerate.execute()) so it does
not assume 24 FPS—either use kwargs.get("fps", 1) as the default or only perform
subsampling when an explicit fps is provided; update the step calculation
accordingly (keep step = max(1, round(fps)) if using fps as the target frame
interval) so callers that don’t pass fps (and the node tooltip that expects 1
FPS batches) do not lose frames by default.
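The subsampling behavior in question can be sketched as follows; `subsample_frames` is a hypothetical helper, not the actual code:

```python
# fps-based frame subsampling with a safe default of 1 FPS, per the
# review: callers that don't pass fps keep every frame.
def subsample_frames(num_frames, fps=1):
    # step = number of source frames to skip per kept frame
    step = max(1, round(fps))
    return list(range(0, num_frames, step))

print(subsample_frames(8))           # [0, 1, 2, 3, 4, 5, 6, 7] -- nothing dropped
print(subsample_frames(48, fps=24))  # [0, 24] -- one frame per second of 24 FPS video
```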
- Around line 1075-1076: The method tokenize_with_weights currently writes
request-scoped state to self by setting self.thinking, which can leak across
concurrent requests (see tokenize_with_weights and the later read that expects
this flag); make thinking a local variable instead of storing it on self: remove
the assignment to self.thinking, use a local thinking variable, and pass that
local variable into any internal helper calls or downstream functions that
previously relied on self.thinking (or change those helpers to accept a thinking
parameter) so no shared tokenizer/CLIP.clone() instances see or store mutable
request state on the object.
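A minimal sketch of keeping `thinking` request-local rather than writing it to `self`; the method bodies are illustrative stand-ins:

```python
# Thread the flag through as a parameter -- no shared mutable state, so
# tokenizers shared across CLIP.clone() instances cannot leak it between
# concurrent requests.
class Tokenizer:
    def tokenize_with_weights(self, text, thinking=False):
        return self._encode(text, thinking=thinking)

    def _encode(self, text, thinking=False):
        suffix = " [thinking]" if thinking else ""
        return text + suffix

t = Tokenizer()
print(t.tokenize_with_weights("hi", thinking=True))  # hi [thinking]
print(t.tokenize_with_weights("hi"))                 # hi
```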
- Around line 624-628: The current dense one-hot construction (creating one_hot
from clamped pixel_position_ids and matmul with position_embedding_table) causes
a huge temporary allocation; replace it by directly gathering the needed rows
from position_embedding_table using the clamped pixel_position_ids. Concretely,
compute clamped_positions = pixel_position_ids.clamp(min=0), move/reshape
pos_table = comfy.model_management.cast_to_device(self.position_embedding_table,
hidden_states.device, hidden_states.dtype) and then gather embeddings for those
indices (e.g., use advanced indexing or torch.gather/torch.take_along_dim to
select the two position rows per patch) and reduce/sum across the position
dimension to produce position_embeddings — avoiding the one_hot -> matmul path
in methods that use pixel_position_ids, position_embedding_table,
position_embedding_size, pos_table, and position_embeddings.
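The one-hot-matmul vs. direct-gather equivalence can be demonstrated with NumPy; the shapes and values below are illustrative stand-ins for the real tensors:

```python
# Gathering rows directly from the embedding table matches the dense
# one_hot @ table path without the large one-hot temporary.
import numpy as np

table = np.arange(12, dtype=np.float32).reshape(6, 2)  # stand-in position_embedding_table
pixel_position_ids = np.array([[0, 3], [2, 5]])        # two position rows per patch

clamped = np.clip(pixel_position_ids, 0, None)

# dense path: builds a one-hot temporary before the matmul
one_hot = np.eye(table.shape[0], dtype=np.float32)[clamped]
dense = (one_hot @ table).sum(axis=1)

# gather path: index the needed rows directly, then sum over positions
gathered = table[clamped].sum(axis=1)

print(np.allclose(dense, gathered))  # True
```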
📒 Files selected for processing (4)
- comfy/ldm/modules/attention.py
- comfy/sd.py
- comfy/text_encoders/gemma4.py
- comfy_extras/nodes_textgen.py
🚧 Files skipped from review as they are similar to previous changes (2)
- comfy_extras/nodes_textgen.py
- comfy/ldm/modules/attention.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy/ops.py`:
- Around line 1228-1244: The code uses weight.params.orig_dtype when computing
target_dtype but the QuantizedTensor API uses the _params attribute; update the
access in the embedding branch to use weight._params.orig_dtype (keep the
surrounding logic in the block that checks isinstance(weight, QuantizedTensor)
and uses cast_bias_weight / uncast_bias_weight and out_dtype) so target_dtype is
derived from weight._params.orig_dtype instead of weight.params.orig_dtype.
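A toy sketch of the attribute fix; these classes only mimic the shape of the `QuantizedTensor` API the comment describes:

```python
# Metadata lives on _params, so orig_dtype must be read from
# weight._params.orig_dtype, not weight.params.orig_dtype.
class _Params:
    def __init__(self, orig_dtype):
        self.orig_dtype = orig_dtype

class QuantizedTensor:
    def __init__(self, orig_dtype):
        self._params = _Params(orig_dtype)

w = QuantizedTensor("float16")
target_dtype = w._params.orig_dtype  # w.params.orig_dtype would raise AttributeError
print(target_dtype)  # float16
```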
📒 Files selected for processing (1)
comfy/ops.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@comfy_extras/nodes_textgen.py`:
- Around line 173-178: The prompt selection in execute currently tests only
image presence and uses LTX2_T2V_SYSTEM_PROMPT for text-only even when video is
provided; update the branch in execute so it treats video the same as image when
deciding between LTX2_T2V_SYSTEM_PROMPT and LTX2_I2V_SYSTEM_PROMPT (e.g., check
"if image is None and video is None" or "if not (image or video)" before
choosing LTX2_T2V_SYSTEM_PROMPT), ensuring formatted_prompt includes the visual
token path when either image or video is present and leaving the rest of the
call to super().execute unchanged.
📒 Files selected for processing (1)
comfy_extras/nodes_textgen.py
```diff
     def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
         else:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
```
Treat video as visual conditioning when selecting the LTX2 system prompt.
Line 174 still chooses the text-only prompt whenever image is None, even if video is present. That means a video-conditioned request gets formatted as plain text-only prompt enhancement, despite Line 178 forwarding the frames downstream.
Minimal fix
```diff
-        if image is None:
+        if image is None and video is None:
```

At minimum, the branch needs to consider video too.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
        if image is None and video is None:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        else:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@comfy_extras/nodes_textgen.py` around lines 173 - 178, The prompt selection
in execute currently tests only image presence and uses LTX2_T2V_SYSTEM_PROMPT
for text-only even when video is provided; update the branch in execute so it
treats video the same as image when deciding between LTX2_T2V_SYSTEM_PROMPT and
LTX2_I2V_SYSTEM_PROMPT (e.g., check "if image is None and video is None" or "if
not (image or video)" before choosing LTX2_T2V_SYSTEM_PROMPT), ensuring
formatted_prompt includes the visual token path when either image or video is
present and leaving the rest of the call to super().execute unchanged.
Thanks for this! @kijai
Adds support for Gemma4 models: E2B, E4B and 31B
https://huggingface.co/Comfy-Org/Gemma4
This is mostly standalone as it includes new functionality:
This implementation was done by referencing the transformers version, and 100% output parity was reached before any optimizations and ComfyUI-specific changes. Those changes are inevitable and do not degrade quality; they just introduce slightly different randomness from very minor details. However, I left most of the debugging aids in place for easier comparison against the reference.
To that end, I also added ways to use torch-native SDPA GQA and scale, as I couldn't find a way to do those manually without tiny numerical differences, as well as an option in our RMS norm to easily disable the fused path for debugging/compatibility reasons.
In the process I also found some small mistakes in my Gemma3 text generation and included fixes for those as well, and unified the image embed scaling method with Gemma4.
I'm also including fp8 scaled embedding support in mixed-precision ops, as the embed_token weights are huge in this model and seem to work fine in fp8.