From 6cbe0a144c251fd1d47e3b4cf3a8a0169473aca9 Mon Sep 17 00:00:00 2001
From: wangshankun
Date: Tue, 14 Apr 2026 07:33:12 +0000
Subject: [PATCH 1/2] [feat] Add dummy model feature

---
 .../seko_talk/seko_talk_01_base_dummy.json    |  17 +++
 .../hf/seko_audio/audio_adapter.py            |   2 +-
 .../hf/seko_audio/audio_encoder.py            |  13 ++-
 .../models/input_encoders/hf/wan/t5/model.py  |  35 +++---
 .../hf/wan/xlm_roberta/model.py               |  11 +-
 lightx2v/models/networks/base_model.py        | 107 ++++++++++++++++--
 lightx2v/models/networks/wan/audio_model.py   |  20 ++++
 .../models/runners/wan/wan_audio_runner.py    |  14 ++-
 lightx2v/models/runners/wan/wan_runner.py     |   4 +
 lightx2v/models/video_encoders/hf/wan/vae.py  |  25 ++--
 .../models/video_encoders/hf/wan/vae_2_2.py   |  35 ++++--
 .../seko_talk/run_seko_talk_01_base_dummy.sh  |  21 ++++
 12 files changed, 245 insertions(+), 59 deletions(-)
 create mode 100644 configs/seko_talk/seko_talk_01_base_dummy.json
 create mode 100755 scripts/seko_talk/run_seko_talk_01_base_dummy.sh

diff --git a/configs/seko_talk/seko_talk_01_base_dummy.json b/configs/seko_talk/seko_talk_01_base_dummy.json
new file mode 100644
index 000000000..994434c2d
--- /dev/null
+++ b/configs/seko_talk/seko_talk_01_base_dummy.json
@@ -0,0 +1,17 @@
+{
+    "infer_steps": 4,
+    "target_fps": 16,
+    "video_duration": 8,
+    "audio_sr": 16000,
+    "target_video_length": 81,
+    "resize_mode": "adaptive",
+    "self_attn_1_type": "flash_attn3",
+    "cross_attn_1_type": "flash_attn3",
+    "cross_attn_2_type": "flash_attn3",
+    "sample_guide_scale": 1.0,
+    "sample_shift": 5,
+    "enable_cfg": false,
+    "cpu_offload": false,
+    "use_31_block": true,
+    "dummy_model": true
+}
diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
index b44b4e2ee..ffebaa9c9 100755
--- a/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
+++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
@@ -6,9 +6,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 
 from lightx2v_platform.base.global_var import AI_DEVICE
diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py
index 5b727af44..4f2692a23 100755
--- a/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py
+++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py
@@ -1,15 +1,17 @@
 import torch
-from transformers import AutoFeatureExtractor, AutoModel
+from loguru import logger
+from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
 
 from lightx2v.utils.envs import *
 from lightx2v_platform.base.global_var import AI_DEVICE
 
 
 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload):
+    def __init__(self, model_path, audio_sr, cpu_offload, dummy_model=False):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
+        self.dummy_model = dummy_model
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
@@ -18,7 +20,12 @@ def __init__(self, model_path, audio_sr, cpu_offload):
 
     def load(self):
         self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
-        self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
+        if self.dummy_model:
+            logger.info("[DummyModel] Skipping audio encoder weight loading, using random init from config")
+            config = AutoConfig.from_pretrained(self.model_path)
+            self.audio_feature_encoder = AutoModel.from_config(config)
+        else:
+            self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
         self.audio_feature_encoder.to(self.device)
         self.audio_feature_encoder.eval()
         self.audio_feature_encoder.to(GET_DTYPE())
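
A minimal standalone sketch of the config-only initialization used in the dummy branch above, assuming any Hugging Face checkpoint directory; the wav2vec2 model id below is purely illustrative:

    from transformers import AutoConfig, AutoModel

    # Reads config.json only; no weight shards are fetched or mapped.
    config = AutoConfig.from_pretrained("facebook/wav2vec2-base")

    # Builds the architecture with randomly initialized parameters, so the
    # module exposes the same shapes, dtypes, and forward signature as the
    # real encoder, which is enough for pipeline and performance smoke tests.
    encoder = AutoModel.from_config(config)
    encoder.eval()
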
diff --git a/lightx2v/models/input_encoders/hf/wan/t5/model.py b/lightx2v/models/input_encoders/hf/wan/t5/model.py
index 17adb176f..88fb3b453 100755
--- a/lightx2v/models/input_encoders/hf/wan/t5/model.py
+++ b/lightx2v/models/input_encoders/hf/wan/t5/model.py
@@ -792,6 +792,7 @@ def __init__(
         quant_scheme=None,
         lazy_load=False,
         load_from_rank0=False,
+        dummy_model=False,
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -821,24 +822,28 @@ def __init__(
             .requires_grad_(False)
         )
 
-        weights_dict = load_weights(
-            self.checkpoint_path,
-            cpu_offload=cpu_offload,
-            load_from_rank0=load_from_rank0,
-        )
+        if not dummy_model:
+            weights_dict = load_weights(
+                self.checkpoint_path,
+                cpu_offload=cpu_offload,
+                load_from_rank0=load_from_rank0,
+            )
 
-        if cpu_offload:
-            block_weights_dict = split_block_weights(weights_dict)
-            if lazy_load:
-                model.blocks_weights.load({})
-            else:
-                model.blocks_weights.load(block_weights_dict)
-            del block_weights_dict
+            if cpu_offload:
+                block_weights_dict = split_block_weights(weights_dict)
+                if lazy_load:
+                    model.blocks_weights.load({})
+                else:
+                    model.blocks_weights.load(block_weights_dict)
+                del block_weights_dict
+                gc.collect()
+
+            model.load_state_dict(weights_dict)
+            del weights_dict
             gc.collect()
+        else:
+            logger.info("[DummyModel] Skipping T5 weight loading, using random init")
 
-        model.load_state_dict(weights_dict)
-        del weights_dict
-        gc.collect()
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
diff --git a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
index 6da8f0532..697219c02 100755
--- a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
+++ b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
@@ -454,7 +454,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14"):
 
 
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, dummy_model=False):
         self.dtype = dtype
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
@@ -470,8 +470,13 @@ def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
-        self.model.load_state_dict(weight_dict)
+        if not dummy_model:
+            weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
+            self.model.load_state_dict(weight_dict)
+        else:
+            from loguru import logger
+
+            logger.info("[DummyModel] Skipping CLIP weight loading, using random init")
 
     def visual(self, videos):
         if self.cpu_offload:
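
Why skipping load_state_dict is safe in both branches above: nn.Module constructors already initialize parameters randomly, so the dummy path simply keeps that initialization. A toy illustration, with a hypothetical module and checkpoint path rather than lightx2v code:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16).eval().requires_grad_(False)

    dummy_model = True
    if not dummy_model:
        # Real path: overwrite the random init with trained weights.
        model.load_state_dict(torch.load("checkpoint.pt"))
    # Dummy path: nothing to load. Outputs are meaningless, but shapes,
    # dtypes, and the compute graph match the real model.
    out = model(torch.randn(1, 16))
    print(out.shape)  # torch.Size([1, 16])
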
diff --git a/lightx2v/models/networks/base_model.py b/lightx2v/models/networks/base_model.py
index aec653f14..f6f576f6e 100755
--- a/lightx2v/models/networks/base_model.py
+++ b/lightx2v/models/networks/base_model.py
@@ -11,7 +11,9 @@
 import gc
 import glob
+import json
 import os
+import struct
 from abc import ABC, abstractmethod
 
 import torch
@@ -25,6 +27,19 @@
 from lightx2v.utils.utils import *
 from lightx2v_platform.base.global_var import AI_DEVICE
 
+SAFETENSORS_DTYPE_MAP = {
+    "F64": torch.float64,
+    "F32": torch.float32,
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "I64": torch.int64,
+    "I32": torch.int32,
+    "I16": torch.int16,
+    "I8": torch.int8,
+    "U8": torch.uint8,
+    "BOOL": torch.bool,
+}
+
 
 class BaseTransformerModel(CompiledMethodsMixin, ABC):
     """Base class for all transformer models.
@@ -126,25 +141,93 @@ def _init_infer_class(self):
         """
         pass
 
+    @staticmethod
+    def _read_safetensors_metadata(file_path):
+        """Read tensor metadata (names, shapes, dtypes) from safetensors file header without loading data."""
+        with open(file_path, "rb") as f:
+            header_size = struct.unpack("<Q", f.read(8))[0]

[remainder of PATCH 1/2 and the header of PATCH 2/2 truncated in source]

Date: Tue, 14 Apr 2026 09:05:05 +0000
Subject: [PATCH 2/2] ci format

---
 lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
index ffebaa9c9..b44b4e2ee 100755
--- a/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
+++ b/lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py
@@ -6,9 +6,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 
 from lightx2v_platform.base.global_var import AI_DEVICE
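
The base_model.py hunk in PATCH 1/2 is cut off mid-function, but the new struct/json imports and SAFETENSORS_DTYPE_MAP point at a standard pattern: parse each safetensors header for tensor names, shapes, and dtypes, then materialize matching dummy tensors without reading any tensor data. A sketch of that pattern under those assumptions; the function names below are illustrative, not the PR's actual API:

    import json
    import struct

    import torch

    def read_safetensors_metadata(file_path):
        with open(file_path, "rb") as f:
            # safetensors layout: an 8-byte little-endian uint64 header size,
            # then a JSON header mapping tensor names to
            # {"dtype": ..., "shape": ..., "data_offsets": ...}.
            header_size = struct.unpack("<Q", f.read(8))[0]
            header = json.loads(f.read(header_size))
        header.pop("__metadata__", None)  # optional free-form metadata entry
        return header

    def build_dummy_state_dict(file_path, dtype_map):
        # Allocate uninitialized tensors matching the checkpoint's shapes
        # and dtypes; no tensor bytes are ever read from disk.
        return {
            name: torch.empty(info["shape"], dtype=dtype_map[info["dtype"]])
            for name, info in read_safetensors_metadata(file_path).items()
        }
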