Conversation
Code Review
This pull request introduces a "dummy model" mode across several model components, including the Audio Encoder, T5, CLIP, and VAE. This feature enables model initialization with random weights by reading only the metadata from safetensors headers, facilitating testing without loading full checkpoints. Review feedback suggests adding error handling for missing checkpoint files and adjusting the random weight initialization scale to 0.02 to ensure numerical stability during dummy inference.
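The header-only trick the PR relies on can be sketched as follows. This is an illustrative, hypothetical helper (`read_safetensors_header` is not the PR's code), assuming the standard safetensors layout: an 8-byte little-endian header length followed by a JSON index of tensor names, dtypes, shapes, and data offsets.

```python
import json
import struct

# Hypothetical sketch of header-only metadata reading (not the PR's code).
# A .safetensors file starts with an 8-byte little-endian u64 giving the
# JSON header size, followed by the header itself; the tensor data comes
# after and never needs to be read to recover shapes and dtypes.
def read_safetensors_header(path):
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_len))
    # Drop the optional "__metadata__" entry; keep tensor name -> meta.
    return {k: v for k, v in header.items() if k != "__metadata__"}

# Build a tiny fake checkpoint to demonstrate: one 2x3 float32 tensor.
meta = {
    "model.weight": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]},
}
blob = json.dumps(meta).encode("utf-8")
with open("dummy.safetensors", "wb") as f:
    f.write(struct.pack("<Q", len(blob)) + blob + b"\x00" * 24)

shapes = {k: v["shape"] for k, v in read_safetensors_header("dummy.safetensors").items()}
print(shapes)  # {'model.weight': [2, 3]}
```

From metadata like this, dummy mode can allocate random tensors of the right shape and dtype without touching the tensor data section.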
```python
if os.path.isdir(safetensors_path):
    safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
    safetensors_files = [safetensors_path]
```
If safetensors_path is a directory containing no .safetensors files, safetensors_files will be an empty list. That produces an empty weight_dict, which can cause silent failures or uninitialized parameters later in the pipeline. It's better to check whether any files were found and raise an informative error if not.
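The suggested guard could look like the sketch below; the helper name, exception type, and message are illustrative, not from the PR.

```python
import glob
import os
import tempfile

# Hypothetical guard for the empty-directory case (sketch, not the PR's
# code): fail loudly instead of silently producing an empty weight_dict.
def collect_safetensors_files(safetensors_path):
    if os.path.isdir(safetensors_path):
        files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
    else:
        files = [safetensors_path]
    if not files:
        raise FileNotFoundError(
            f"No .safetensors files found under {safetensors_path!r}"
        )
    return files

# An empty directory now raises instead of returning [].
with tempfile.TemporaryDirectory() as empty_dir:
    try:
        collect_safetensors_files(empty_dir)
        raised = False
    except FileNotFoundError:
        raised = True
print(raised)  # True
```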
```python
original_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype_str)
if original_dtype is not None and not original_dtype.is_floating_point:
    dtype = original_dtype
weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
```
Using torch.randn with default parameters (mean=0, std=1) for dummy weight initialization can lead to numerical instability or activations exploding in deep networks, potentially causing NaNs during inference. For transformer models, a smaller standard deviation (e.g., 0.02) is generally safer and more representative of actual weight distributions.
Suggested change:

```diff
-weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
+weight_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
```
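Why the 0.02 scale matters can be seen with a toy depth experiment. The sketch below uses numpy instead of torch and makes simplifying assumptions (purely linear layers, no normalization), but the mechanism is the same: each matmul scales the activation norm by roughly std * sqrt(fan_in), so unit-variance weights compound into astronomically large activations with depth.

```python
import numpy as np

# Toy illustration (not the PR's code): push an activation through 20
# linear layers whose weights are drawn with std=1 vs std=0.02.
# With fan_in=64, std=1 grows the norm ~8x per layer (overflowing
# float16 range in a real model), while std=0.02 keeps it small.
rng = np.random.default_rng(0)

def final_norm(std, depth=20, width=64):
    x = rng.standard_normal(width)
    for _ in range(depth):
        w = rng.standard_normal((width, width)) * std
        x = w @ x
    return float(np.linalg.norm(x))

norm_unit = final_norm(std=1.0)     # explodes: roughly 8**20 growth
norm_scaled = final_norm(std=0.02)  # shrinks toward zero, stays finite
print(norm_unit > 1e6, norm_scaled < 1e-3)  # True True
```

Real checkpoints avoid this because trained transformer weights typically have small standard deviations (on the order of 0.02), which is exactly what the suggested scale imitates.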
```python
original_dtype = SAFETENSORS_DTYPE_MAP.get(meta["dtype"])
if original_dtype is not None and not original_dtype.is_floating_point:
    dtype = original_dtype
adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
```
As with the base model weights, using torch.randn with unit variance for dummy adapter weights can cause numerical issues. Scaling by a factor such as 0.02 is recommended for more stable dummy initialization.
Suggested change:

```diff
-adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
+adapter_weights_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
```