diff --git a/configs/seko_talk/seko_talk_01_base_dummy.json b/configs/seko_talk/seko_talk_01_base_dummy.json new file mode 100644 index 000000000..994434c2d --- /dev/null +++ b/configs/seko_talk/seko_talk_01_base_dummy.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 8, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "adaptive", + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + "sample_guide_scale": 1.0, + "sample_shift": 5, + "enable_cfg": false, + "cpu_offload": false, + "use_31_block": true, + "dummy_model": true +} diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py index 5b727af44..4f2692a23 100755 --- a/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py +++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py @@ -1,15 +1,17 @@ import torch -from transformers import AutoFeatureExtractor, AutoModel +from loguru import logger +from transformers import AutoConfig, AutoFeatureExtractor, AutoModel from lightx2v.utils.envs import * from lightx2v_platform.base.global_var import AI_DEVICE class SekoAudioEncoderModel: - def __init__(self, model_path, audio_sr, cpu_offload): + def __init__(self, model_path, audio_sr, cpu_offload, dummy_model=False): self.model_path = model_path self.audio_sr = audio_sr self.cpu_offload = cpu_offload + self.dummy_model = dummy_model if self.cpu_offload: self.device = torch.device("cpu") else: @@ -18,7 +20,12 @@ def __init__(self, model_path, audio_sr, cpu_offload): def load(self): self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path) - self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path) + if self.dummy_model: + logger.info("[DummyModel] Skipping audio encoder weight loading, using random init from config") + config = AutoConfig.from_pretrained(self.model_path) + self.audio_feature_encoder = AutoModel.from_config(config) + else: + self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path) self.audio_feature_encoder.to(self.device) self.audio_feature_encoder.eval() self.audio_feature_encoder.to(GET_DTYPE()) diff --git a/lightx2v/models/input_encoders/hf/wan/t5/model.py b/lightx2v/models/input_encoders/hf/wan/t5/model.py index 17adb176f..88fb3b453 100755 --- a/lightx2v/models/input_encoders/hf/wan/t5/model.py +++ b/lightx2v/models/input_encoders/hf/wan/t5/model.py @@ -792,6 +792,7 @@ def __init__( quant_scheme=None, lazy_load=False, load_from_rank0=False, + dummy_model=False, ): self.text_len = text_len self.dtype = dtype @@ -821,24 +822,28 @@ def __init__( .requires_grad_(False) ) - weights_dict = load_weights( - self.checkpoint_path, - cpu_offload=cpu_offload, - load_from_rank0=load_from_rank0, - ) + if not dummy_model: + weights_dict = load_weights( + self.checkpoint_path, + cpu_offload=cpu_offload, + load_from_rank0=load_from_rank0, + ) - if cpu_offload: - block_weights_dict = split_block_weights(weights_dict) - if lazy_load: - model.blocks_weights.load({}) - else: - model.blocks_weights.load(block_weights_dict) - del block_weights_dict + if cpu_offload: + block_weights_dict = split_block_weights(weights_dict) + if lazy_load: + model.blocks_weights.load({}) + else: + model.blocks_weights.load(block_weights_dict) + del block_weights_dict + gc.collect() + + model.load_state_dict(weights_dict) + del weights_dict gc.collect() + else: + logger.info("[DummyModel] Skipping T5 weight loading, using random init") - model.load_state_dict(weights_dict) - del weights_dict - gc.collect() self.model = model if shard_fn is not None: self.model = shard_fn(self.model, sync_module_states=False) diff --git a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py index 6da8f0532..697219c02 100755 --- a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py +++ b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py @@ -454,7 +454,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r class CLIPModel: - def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False): + def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, dummy_model=False): self.dtype = dtype self.quantized = clip_quantized self.cpu_offload = cpu_offload @@ -470,8 +470,13 @@ def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantize pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme ) self.model = self.model.eval().requires_grad_(False) - weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0) - self.model.load_state_dict(weight_dict) + if not dummy_model: + weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0) + self.model.load_state_dict(weight_dict) + else: + from loguru import logger + + logger.info("[DummyModel] Skipping CLIP weight loading, using random init") def visual(self, videos): if self.cpu_offload: diff --git a/lightx2v/models/networks/base_model.py b/lightx2v/models/networks/base_model.py index aec653f14..f6f576f6e 100755 --- a/lightx2v/models/networks/base_model.py +++ b/lightx2v/models/networks/base_model.py @@ -11,7 +11,9 @@ import gc import glob +import json import os +import struct from abc import ABC, abstractmethod import torch @@ -25,6 +27,19 @@ from lightx2v.utils.utils import * from lightx2v_platform.base.global_var import AI_DEVICE +SAFETENSORS_DTYPE_MAP = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, +} + class BaseTransformerModel(CompiledMethodsMixin, ABC): """Base class for all transformer models. @@ -126,25 +141,93 @@ def _init_infer_class(self): """ pass + @staticmethod + def _read_safetensors_metadata(file_path): + """Read tensor metadata (names, shapes, dtypes) from safetensors file header without loading data.""" + with open(file_path, "rb") as f: + header_size = struct.unpack("