
[feat] Add dummy model feature #1009

Merged
helloyongyang merged 2 commits into main from dev/dummy_model
Apr 14, 2026

Conversation

@wangshankun (Collaborator)

No description provided.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a "dummy model" mode across several model components, including the Audio Encoder, T5, CLIP, and VAE. This feature enables model initialization with random weights by reading only the metadata from safetensors headers, facilitating testing without loading full checkpoints. Review feedback suggests adding error handling for missing checkpoint files and adjusting the random weight initialization scale to 0.02 to ensure numerical stability during dummy inference.
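The mechanism described above can be sketched as follows. This is not the PR's actual code: `SAFETENSORS_DTYPE_MAP` follows the names in the snippets below, but the helper functions and the 0.02 scale (the reviewer's suggested value) are illustrative. A safetensors file begins with an 8-byte little-endian header length followed by a JSON header recording each tensor's shape and dtype, which is all a dummy mode needs:

```python
import json
import struct

import torch

# Hedged sketch: build random "dummy" weights from a safetensors checkpoint by
# parsing only its JSON header, never reading the tensor data itself.
SAFETENSORS_DTYPE_MAP = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
}


def read_safetensors_header(path):
    """Parse the safetensors header: an 8-byte little-endian length, then JSON."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def build_dummy_weights(path, dtype=torch.float32, device="cpu"):
    """Random weights matching the checkpoint's shapes, without loading it."""
    weight_dict = {}
    for key, meta in read_safetensors_header(path).items():
        if key == "__metadata__":  # optional free-form metadata entry
            continue
        shape = meta["shape"]
        original_dtype = SAFETENSORS_DTYPE_MAP.get(meta["dtype"])
        # Keep non-floating tensors in their original dtype; floats use `dtype`.
        if original_dtype is not None and not original_dtype.is_floating_point:
            target = original_dtype
        else:
            target = dtype
        if target.is_floating_point:
            # Small std (0.02, per the review below-unit-variance advice) keeps
            # dummy activations numerically stable in deep networks.
            weight_dict[key] = torch.randn(shape, dtype=target, device=device) * 0.02
        else:
            weight_dict[key] = torch.zeros(shape, dtype=target, device=device)
    return weight_dict
```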

Comment on lines +174 to +177
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]

Severity: medium

If safetensors_path is a directory but contains no .safetensors files, safetensors_files will be an empty list. This will result in an empty weight_dict, which might cause silent failures or uninitialized parameters later in the pipeline. It's better to check if any files were found and raise an informative error if not.
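The reviewer's suggested fix could look like the sketch below. The function name and error message are illustrative; the directory/file branching mirrors the snippet above:

```python
import glob
import os


def collect_safetensors_files(safetensors_path):
    """Resolve a shard directory or a single checkpoint file to a file list."""
    if os.path.isdir(safetensors_path):
        # sorted() gives a deterministic shard order across filesystems.
        safetensors_files = sorted(
            glob.glob(os.path.join(safetensors_path, "*.safetensors"))
        )
    else:
        safetensors_files = [safetensors_path]
    # Fail loudly here instead of letting an empty weight_dict cause silent
    # failures or uninitialized parameters later in the pipeline.
    if not safetensors_files:
        raise FileNotFoundError(
            f"No .safetensors files found in directory: {safetensors_path}"
        )
    return safetensors_files
```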

original_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype_str)
if original_dtype is not None and not original_dtype.is_floating_point:
dtype = original_dtype
weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)

Severity: medium

Using torch.randn with default parameters (mean=0, std=1) for dummy weight initialization can lead to numerical instability or activations exploding in deep networks, potentially causing NaNs during inference. For transformer models, a smaller standard deviation (e.g., 0.02) is generally safer and more representative of actual weight distributions.

Suggested change:
- weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
+ weight_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)

original_dtype = SAFETENSORS_DTYPE_MAP.get(meta["dtype"])
if original_dtype is not None and not original_dtype.is_floating_point:
dtype = original_dtype
adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)

Severity: medium

Similar to the base model, using torch.randn with unit variance for dummy weights can cause numerical issues. Scaling by a factor like 0.02 is recommended for more stable dummy initialization.

Suggested change:
- adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
+ adapter_weights_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)

@helloyongyang helloyongyang merged commit 6db002f into main Apr 14, 2026
2 checks passed
@helloyongyang helloyongyang deleted the dev/dummy_model branch April 14, 2026 09:06


2 participants