Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions configs/seko_talk/seko_talk_01_base_dummy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 8,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": true,
"dummy_model": true
}
13 changes: 10 additions & 3 deletions lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import torch
from transformers import AutoFeatureExtractor, AutoModel
from loguru import logger
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel

from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE


class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr, cpu_offload):
def __init__(self, model_path, audio_sr, cpu_offload, dummy_model=False):
self.model_path = model_path
self.audio_sr = audio_sr
self.cpu_offload = cpu_offload
self.dummy_model = dummy_model
if self.cpu_offload:
self.device = torch.device("cpu")
else:
Expand All @@ -18,7 +20,12 @@ def __init__(self, model_path, audio_sr, cpu_offload):

def load(self):
self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
if self.dummy_model:
logger.info("[DummyModel] Skipping audio encoder weight loading, using random init from config")
config = AutoConfig.from_pretrained(self.model_path)
self.audio_feature_encoder = AutoModel.from_config(config)
else:
self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
self.audio_feature_encoder.to(self.device)
self.audio_feature_encoder.eval()
self.audio_feature_encoder.to(GET_DTYPE())
Expand Down
35 changes: 20 additions & 15 deletions lightx2v/models/input_encoders/hf/wan/t5/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def __init__(
quant_scheme=None,
lazy_load=False,
load_from_rank0=False,
dummy_model=False,
):
self.text_len = text_len
self.dtype = dtype
Expand Down Expand Up @@ -821,24 +822,28 @@ def __init__(
.requires_grad_(False)
)

weights_dict = load_weights(
self.checkpoint_path,
cpu_offload=cpu_offload,
load_from_rank0=load_from_rank0,
)
if not dummy_model:
weights_dict = load_weights(
self.checkpoint_path,
cpu_offload=cpu_offload,
load_from_rank0=load_from_rank0,
)

if cpu_offload:
block_weights_dict = split_block_weights(weights_dict)
if lazy_load:
model.blocks_weights.load({})
else:
model.blocks_weights.load(block_weights_dict)
del block_weights_dict
if cpu_offload:
block_weights_dict = split_block_weights(weights_dict)
if lazy_load:
model.blocks_weights.load({})
else:
model.blocks_weights.load(block_weights_dict)
del block_weights_dict
gc.collect()

model.load_state_dict(weights_dict)
del weights_dict
gc.collect()
else:
logger.info("[DummyModel] Skipping T5 weight loading, using random init")

model.load_state_dict(weights_dict)
del weights_dict
gc.collect()
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
Expand Down
11 changes: 8 additions & 3 deletions lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r


class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, dummy_model=False):
self.dtype = dtype
self.quantized = clip_quantized
self.cpu_offload = cpu_offload
Expand All @@ -470,8 +470,13 @@ def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantize
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
self.model.load_state_dict(weight_dict)
if not dummy_model:
weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
self.model.load_state_dict(weight_dict)
else:
from loguru import logger

logger.info("[DummyModel] Skipping CLIP weight loading, using random init")

def visual(self, videos):
if self.cpu_offload:
Expand Down
107 changes: 95 additions & 12 deletions lightx2v/models/networks/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

import gc
import glob
import json
import os
import struct
from abc import ABC, abstractmethod

import torch
Expand All @@ -25,6 +27,19 @@
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE

SAFETENSORS_DTYPE_MAP = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}


class BaseTransformerModel(CompiledMethodsMixin, ABC):
"""Base class for all transformer models.
Expand Down Expand Up @@ -126,25 +141,93 @@ def _init_infer_class(self):
"""
pass

@staticmethod
def _read_safetensors_metadata(file_path):
"""Read tensor metadata (names, shapes, dtypes) from safetensors file header without loading data."""
with open(file_path, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
header = json.loads(header_json)
tensors = {}
for key, info in header.items():
if key == "__metadata__":
continue
tensors[key] = {"shape": info["shape"], "dtype": info["dtype"]}
return tensors

def _load_dummy_ckpt(self, unified_dtype, sensitive_layer):
"""Generate random weight dict by reading safetensors metadata, without loading actual pretrained data.

When dummy_model is enabled, this reads only the file headers to determine tensor names/shapes,
then allocates random tensors on self.device (derived from AI_DEVICE / cpu_offload).
"""
dummy_device = str(self.device)
logger.info(f"[DummyModel] Generating random weights on device={dummy_device}")

if self.config.get("dit_original_ckpt", None):
safetensors_path = self.config["dit_original_ckpt"]
elif self.config.get("dit_quantized_ckpt", None) and self.dit_quantized:
safetensors_path = self.config["dit_quantized_ckpt"]
else:
safetensors_path = self.model_path

if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
Comment on lines +174 to +177
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If safetensors_path is a directory but contains no .safetensors files, safetensors_files will be an empty list. This will result in an empty weight_dict, which might cause silent failures or uninitialized parameters later in the pipeline. It's better to check if any files were found and raise an informative error if not.


remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
preserve_keys = self.preserved_keys if hasattr(self, "preserved_keys") else None

weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == file_path:
continue
logger.info(f"[DummyModel] Reading metadata from {file_path}")
tensors_meta = self._read_safetensors_metadata(file_path)
for key, meta in tensors_meta.items():
if any(rk in key for rk in remove_keys):
continue
if preserve_keys is not None and not any(pk in key for pk in preserve_keys):
continue
shape = meta["shape"]
st_dtype_str = meta["dtype"]
if unified_dtype or all(s not in key for s in sensitive_layer):
dtype = GET_DTYPE()
else:
dtype = GET_SENSITIVE_DTYPE()
original_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype_str)
if original_dtype is not None and not original_dtype.is_floating_point:
dtype = original_dtype
weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using torch.randn with default parameters (mean=0, std=1) for dummy weight initialization can lead to numerical instability or activations exploding in deep networks, potentially causing NaNs during inference. For transformer models, a smaller standard deviation (e.g., 0.02) is generally safer and more representative of actual weight distributions.

Suggested change
weight_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
weight_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)


return weight_dict

def _init_weights(self, weight_dict=None):
unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
# Some layers run with float32 to achieve high accuracy
sensitive_layer = self.sensitive_layer
if weight_dict is None:
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized:
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
# Load quantized weights
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
if self.config.get("dummy_model", False):
weight_dict = self._load_dummy_ckpt(unified_dtype, sensitive_layer)
if hasattr(self, "_load_adapter_ckpt"):
weight_dict.update(self._load_adapter_ckpt())
else:
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized:
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
# Load quantized weights
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

if (self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False)) or (hasattr(self, "use_tp") and self.use_tp):
weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
if (self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False)) or (hasattr(self, "use_tp") and self.use_tp):
weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)

if hasattr(self, "_load_adapter_ckpt"):
weight_dict.update(self._load_adapter_ckpt())
if hasattr(self, "_load_adapter_ckpt"):
weight_dict.update(self._load_adapter_ckpt())

self.original_weight_dict = weight_dict
else:
Expand Down
20 changes: 20 additions & 0 deletions lightx2v/models/networks/wan/audio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ def _load_adapter_ckpt(self):
adapter_model_name = "audio_adapter_model.safetensors"
self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)

if self.config.get("dummy_model", False):
from lightx2v.models.networks.base_model import SAFETENSORS_DTYPE_MAP, BaseTransformerModel

dummy_device = str(self.device)
logger.info(f"[DummyModel] Generating random adapter weights on device={dummy_device}")
tensors_meta = BaseTransformerModel._read_safetensors_metadata(self.config["adapter_model_path"])
adapter_weights_dict = {}
from lightx2v.utils.envs import GET_DTYPE

for key, meta in tensors_meta.items():
if "audio" in key:
continue
shape = meta["shape"]
dtype = GET_DTYPE()
original_dtype = SAFETENSORS_DTYPE_MAP.get(meta["dtype"])
if original_dtype is not None and not original_dtype.is_floating_point:
dtype = original_dtype
adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the base model, using torch.randn with unit variance for dummy weights can cause numerical issues. Scaling by a factor like 0.02 is recommended for more stable dummy initialization.

Suggested change
adapter_weights_dict[key] = torch.randn(shape, dtype=dtype, device=dummy_device) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)
adapter_weights_dict[key] = (torch.randn(shape, dtype=dtype, device=dummy_device) * 0.02) if dtype.is_floating_point else torch.zeros(shape, dtype=dtype, device=dummy_device)

return adapter_weights_dict

adapter_offload = self.config.get("cpu_offload", False)
load_from_rank0 = self.config.get("load_from_rank0", False)
adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
Expand Down
14 changes: 10 additions & 4 deletions lightx2v/models/runners/wan/wan_audio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,8 @@ def load_transformer(self):
def load_audio_encoder(self):
audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
dummy_model = self.config.get("dummy_model", False)
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, dummy_model=dummy_model)
return model

def load_audio_adapter(self):
Expand All @@ -830,9 +831,12 @@ def load_audio_adapter(self):
)

audio_adapter.to(device)
load_from_rank0 = self.config.get("load_from_rank0", False)
weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
audio_adapter.load_state_dict(weights_dict, strict=False)
if not self.config.get("dummy_model", False):
load_from_rank0 = self.config.get("load_from_rank0", False)
weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
audio_adapter.load_state_dict(weights_dict, strict=False)
else:
logger.info("[DummyModel] Skipping audio adapter weight loading, using random init")
return audio_adapter.to(dtype=GET_DTYPE())

def load_model(self):
Expand Down Expand Up @@ -941,6 +945,7 @@ def load_vae_decoder(self):
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
"dummy_model": self.config.get("dummy_model", False),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
Expand All @@ -957,6 +962,7 @@ def load_vae_encoder(self):
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
"dummy_model": self.config.get("dummy_model", False),
}
if self.config.task not in ["i2v", "s2v", "rs2v"]:
return None
Expand Down
4 changes: 4 additions & 0 deletions lightx2v/models/runners/wan/wan_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def load_image_encoder(self):
cpu_offload=clip_offload,
use_31_block=self.config.get("use_31_block", True),
load_from_rank0=self.config.get("load_from_rank0", False),
dummy_model=self.config.get("dummy_model", False),
)

return image_encoder
Expand Down Expand Up @@ -162,6 +163,7 @@ def load_text_encoder(self):
quant_scheme=t5_quant_scheme,
load_from_rank0=self.config.get("load_from_rank0", False),
lazy_load=self.config.get("t5_lazy_load", False),
dummy_model=self.config.get("dummy_model", False),
)
text_encoders = [text_encoder]
return text_encoders
Expand Down Expand Up @@ -190,6 +192,7 @@ def load_vae_encoder(self):
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
"use_lightvae": self.config.get("use_lightvae", False),
"dummy_model": self.config.get("dummy_model", False),
}
if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v", "rs2v"]:
return None
Expand All @@ -213,6 +216,7 @@ def load_vae_decoder(self):
"use_lightvae": self.config.get("use_lightvae", False),
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
"dummy_model": self.config.get("dummy_model", False),
}
if self.config.get("use_tae", False):
tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
Expand Down
Loading
Loading