-
Notifications
You must be signed in to change notification settings - Fork 186
[feat] Add dummy model feature #1009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| { | ||
| "infer_steps": 4, | ||
| "target_fps": 16, | ||
| "video_duration": 8, | ||
| "audio_sr": 16000, | ||
| "target_video_length": 81, | ||
| "resize_mode": "adaptive", | ||
| "self_attn_1_type": "flash_attn3", | ||
| "cross_attn_1_type": "flash_attn3", | ||
| "cross_attn_2_type": "flash_attn3", | ||
| "sample_guide_scale": 1.0, | ||
| "sample_shift": 5, | ||
| "enable_cfg": false, | ||
| "cpu_offload": false, | ||
| "use_31_block": true, | ||
| "dummy_model": true | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -11,7 +11,9 @@ | |||||
|
|
||||||
| import gc | ||||||
| import glob | ||||||
| import json | ||||||
| import os | ||||||
| import struct | ||||||
| from abc import ABC, abstractmethod | ||||||
|
|
||||||
| import torch | ||||||
|
|
@@ -25,6 +27,19 @@ | |||||
| from lightx2v.utils.utils import * | ||||||
| from lightx2v_platform.base.global_var import AI_DEVICE | ||||||
|
|
||||||
| SAFETENSORS_DTYPE_MAP = { | ||||||
| "F64": torch.float64, | ||||||
| "F32": torch.float32, | ||||||
| "F16": torch.float16, | ||||||
| "BF16": torch.bfloat16, | ||||||
| "I64": torch.int64, | ||||||
| "I32": torch.int32, | ||||||
| "I16": torch.int16, | ||||||
| "I8": torch.int8, | ||||||
| "U8": torch.uint8, | ||||||
| "BOOL": torch.bool, | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| class BaseTransformerModel(CompiledMethodsMixin, ABC): | ||||||
| """Base class for all transformer models. | ||||||
|
|
@@ -126,25 +141,93 @@ def _init_infer_class(self): | |||||
| """ | ||||||
| pass | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _read_safetensors_metadata(file_path): | ||||||
| """Read tensor metadata (names, shapes, dtypes) from safetensors file header without loading data.""" | ||||||
| with open(file_path, "rb") as f: | ||||||
| header_size = struct.unpack("<Q", f.read(8))[0] | ||||||
| header_json = f.read(header_size).decode("utf-8") | ||||||
| header = json.loads(header_json) | ||||||
| tensors = {} | ||||||
| for key, info in header.items(): | ||||||
| if key == "__metadata__": | ||||||
| continue | ||||||
| tensors[key] = {"shape": info["shape"], "dtype": info["dtype"]} | ||||||
| return tensors | ||||||
|
|
||||||
| def _load_dummy_ckpt(self, unified_dtype, sensitive_layer): | ||||||
| """Generate random weight dict by reading safetensors metadata, without loading actual pretrained data. | ||||||
|
|
||||||
| When dummy_model is enabled, this reads only the file headers to determine tensor names/shapes, | ||||||
| then allocates random tensors on self.device (derived from AI_DEVICE / cpu_offload). | ||||||
| """ | ||||||
| dummy_device = str(self.device) | ||||||
| logger.info(f"[DummyModel] Generating random weights on device={dummy_device}") | ||||||
|
|
||||||
| if self.config.get("dit_original_ckpt", None): | ||||||
| safetensors_path = self.config["dit_original_ckpt"] | ||||||
| elif self.config.get("dit_quantized_ckpt", None) and self.dit_quantized: | ||||||
| safetensors_path = self.config["dit_quantized_ckpt"] | ||||||
| else: | ||||||
| safetensors_path = self.model_path | ||||||
|
|
||||||
| if os.path.isdir(safetensors_path): | ||||||
| safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors")) | ||||||
| else: | ||||||
| safetensors_files = [safetensors_path] | ||||||
|
|
||||||
| remove_keys = self.remove_keys if hasattr(self, "remove_keys") else [] | ||||||
| preserve_keys = self.preserved_keys if hasattr(self, "preserved_keys") else None | ||||||
|
|
||||||
| weight_dict = {} | ||||||
| for file_path in safetensors_files: | ||||||
| if self.config.get("adapter_model_path", None) is not None: | ||||||
| if self.config["adapter_model_path"] == file_path: | ||||||
| continue | ||||||
| logger.info(f"[DummyModel] Reading metadata from {file_path}") | ||||||
| tensors_meta = self._read_safetensors_metadata(file_path) | ||||||
| for key, meta in tensors_meta.items(): | ||||||
| if any(rk in key for rk in remove_keys): | ||||||
| continue | ||||||
| if preserve_keys is not None and not any(pk in key for pk in preserve_keys): | ||||||
| continue | ||||||
| shape = meta["shape"] | ||||||
| st_dtype_str = meta["dtype"] | ||||||
| if unified_dtype or all(s not in key for s in sensitive_layer): | ||||||
| dtype = GET_DTYPE() | ||||||
| else: | ||||||
| dtype = GET_SENSITIVE_DTYPE() | ||||||
| original_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype_str) | ||||||
| if original_dtype is not None and not original_dtype.is_floating_point: | ||||||
| dtype = original_dtype | ||||||
| weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||
|
|
||||||
| return weight_dict | ||||||
|
|
||||||
| def _init_weights(self, weight_dict=None): | ||||||
| unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE() | ||||||
| # Some layers run with float32 to achieve high accuracy | ||||||
| sensitive_layer = self.sensitive_layer | ||||||
| if weight_dict is None: | ||||||
| is_weight_loader = self._should_load_weights() | ||||||
| if is_weight_loader: | ||||||
| if not self.dit_quantized: | ||||||
| # Load original weights | ||||||
| weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) | ||||||
| else: | ||||||
| # Load quantized weights | ||||||
| weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) | ||||||
| if self.config.get("dummy_model", False): | ||||||
| weight_dict = self._load_dummy_ckpt(unified_dtype, sensitive_layer) | ||||||
| if hasattr(self, "_load_adapter_ckpt"): | ||||||
| weight_dict.update(self._load_adapter_ckpt()) | ||||||
| else: | ||||||
| is_weight_loader = self._should_load_weights() | ||||||
| if is_weight_loader: | ||||||
| if not self.dit_quantized: | ||||||
| # Load original weights | ||||||
| weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) | ||||||
| else: | ||||||
| # Load quantized weights | ||||||
| weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) | ||||||
|
|
||||||
| if (self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False)) or (hasattr(self, "use_tp") and self.use_tp): | ||||||
| weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) | ||||||
| if (self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False)) or (hasattr(self, "use_tp") and self.use_tp): | ||||||
| weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) | ||||||
|
|
||||||
| if hasattr(self, "_load_adapter_ckpt"): | ||||||
| weight_dict.update(self._load_adapter_ckpt()) | ||||||
| if hasattr(self, "_load_adapter_ckpt"): | ||||||
| weight_dict.update(self._load_adapter_ckpt()) | ||||||
|
|
||||||
| self.original_weight_dict = weight_dict | ||||||
| else: | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -43,6 +43,26 @@ def _load_adapter_ckpt(self): | |||||
| adapter_model_name = "audio_adapter_model.safetensors" | ||||||
| self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name) | ||||||
|
|
||||||
| if self.config.get("dummy_model", False): | ||||||
| from lightx2v.models.networks.base_model import SAFETENSORS_DTYPE_MAP, BaseTransformerModel | ||||||
|
|
||||||
| dummy_device = str(self.device) | ||||||
| logger.info(f"[DummyModel] Generating random adapter weights on device={dummy_device}") | ||||||
| tensors_meta = BaseTransformerModel._read_safetensors_metadata(self.config["adapter_model_path"]) | ||||||
| adapter_weights_dict = {} | ||||||
| from lightx2v.utils.envs import GET_DTYPE | ||||||
|
|
||||||
| for key, meta in tensors_meta.items(): | ||||||
| if "audio" in key: | ||||||
| continue | ||||||
| shape = meta["shape"] | ||||||
| dtype = GET_DTYPE() | ||||||
| original_dtype = SAFETENSORS_DTYPE_MAP.get(meta["dtype"]) | ||||||
| if original_dtype is not None and not original_dtype.is_floating_point: | ||||||
| dtype = original_dtype | ||||||
| adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the base model, using
Suggested change
|
||||||
| return adapter_weights_dict | ||||||
|
|
||||||
| adapter_offload = self.config.get("cpu_offload", False) | ||||||
| load_from_rank0 = self.config.get("load_from_rank0", False) | ||||||
| adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `safetensors_path` is a directory but contains no `.safetensors` files, `safetensors_files` will be an empty list. This will result in an empty `weight_dict`, which might cause silent failures or uninitialized parameters later in the pipeline. It's better to check if any files were found and raise an informative error if not.