Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 74 additions & 8 deletions lightx2v/models/video_encoders/hf/qwen_image/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

import torch
import torch.distributed as dist

from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
Expand All @@ -23,6 +24,7 @@ def __init__(self, config):
if self.is_layered:
self.layers = config.get("layers", 4)

self.vae_decode_parallel = config.get("vae_decode_parallel", False)
self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
if self.cpu_offload:
self.device = torch.device("cpu")
Expand Down Expand Up @@ -61,6 +63,69 @@ def _unpack_latents(latents, height, width, vae_scale_factor, layers=None):

return latents

def _get_2d_grid(self, total_h, total_w, world_size):
best_h, best_w = 1, world_size
min_aspect_diff = float("inf")
for h in range(1, world_size + 1):
if world_size % h == 0:
w = world_size // h
if total_h % h == 0 and total_w % w == 0:
aspect_diff = abs((total_h / h) - (total_w / w))
if aspect_diff < min_aspect_diff:
min_aspect_diff = aspect_diff
best_h, best_w = h, w
return best_h, best_w
Comment on lines +66 to +77
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _get_2d_grid function currently defaults to (1, world_size) if no grid is found that perfectly divides the latent dimensions. If the latent width is not divisible by world_size, the slicing logic in _decode_dist will result in gaps between spatial chunks, leading to corrupted output images. Additionally, the logic should verify that the resulting chunks are large enough to accommodate the required padding (i.e., total_h // h >= 2 * padding).

Suggested change
def _get_2d_grid(self, total_h, total_w, world_size):
best_h, best_w = 1, world_size
min_aspect_diff = float("inf")
for h in range(1, world_size + 1):
if world_size % h == 0:
w = world_size // h
if total_h % h == 0 and total_w % w == 0:
aspect_diff = abs((total_h / h) - (total_w / w))
if aspect_diff < min_aspect_diff:
min_aspect_diff = aspect_diff
best_h, best_w = h, w
return best_h, best_w
def _get_2d_grid(self, total_h, total_w, world_size, padding=2):
best_h, best_w = None, None
min_aspect_diff = float("inf")
for h in range(1, world_size + 1):
if world_size % h == 0:
w = world_size // h
if total_h % h == 0 and total_w % w == 0:
if total_h // h >= 2 * padding and total_w // w >= 2 * padding:
aspect_diff = abs((total_h / h) - (total_w / w))
if aspect_diff < min_aspect_diff:
min_aspect_diff = aspect_diff
best_h, best_w = h, w
return best_h, best_w


def _decode_dist(self, latents):
    """Decode latents in parallel by splitting them spatially across ranks.

    Each rank decodes one overlap-padded 2D chunk of the latent map
    (latents: (b, c, 1, lh, lw)), trims the decoded overlap away, and the
    full image is reassembled on every rank via all_gather + concat.
    """
    ws = dist.get_world_size()
    rank = dist.get_rank()
    lat_h, lat_w = latents.shape[3], latents.shape[4]
    grid_h, grid_w = self._get_2d_grid(lat_h, lat_w, ws)
    row, col = divmod(rank, grid_w)
    chunk_h, chunk_w = lat_h // grid_h, lat_w // grid_w
    # pad: latent-cell overlap kept around each chunk so the VAE sees context
    # at the seams; ratio: latent -> pixel upscale factor of the decoder.
    pad, ratio = 2, 8

    def latent_span(idx, parts, chunk, total):
        # Padded [start, end) slice of one latent axis for grid position idx.
        if idx == 0:
            return 0, chunk + 2 * pad
        if idx == parts - 1:
            return total - chunk - 2 * pad, total
        return idx * chunk - pad, (idx + 1) * chunk + pad

    def pixel_span(idx, parts, chunk, size):
        # Trim [start, end) that removes the decoded overlap padding again.
        if idx == 0:
            return 0, chunk * ratio
        if idx == parts - 1:
            return size - chunk * ratio, size
        return pad * ratio, size - pad * ratio

    h0, h1 = latent_span(row, grid_h, chunk_h, lat_h)
    w0, w1 = latent_span(col, grid_w, chunk_w, lat_w)
    chunk = latents[:, :, :, h0:h1, w0:w1].contiguous()
    decoded = self.model.decode(chunk, return_dict=False)[0]  # (b, c, 1, H', W')

    dh0, dh1 = pixel_span(row, grid_h, chunk_h, decoded.shape[3])
    dw0, dw1 = pixel_span(col, grid_w, chunk_w, decoded.shape[4])
    piece = decoded[:, :, :, dh0:dh1, dw0:dw1].contiguous()

    gathered = [torch.empty_like(piece) for _ in range(grid_h * grid_w)]
    dist.all_gather(gathered, piece)

    # Stitch: concat each grid row along width, then the rows along height.
    stitched_rows = [torch.cat(gathered[r * grid_w : (r + 1) * grid_w], dim=4) for r in range(grid_h)]
    return torch.cat(stitched_rows, dim=3)

@torch.no_grad()
def decode(self, latents, input_info):
if self.cpu_offload:
Expand All @@ -74,19 +139,20 @@ def decode(self, latents, input_info):
latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean

use_vae_decode_parallel = self.vae_decode_parallel and dist.is_initialized() and dist.get_world_size() > 1

if self.is_layered:
b, c, f, h, w = latents.shape
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
image = self.model.decode(latents, return_dict=False)[0] # (b f) c 1 h w
latents = latents[:, :, 1:].permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]
image = image.squeeze(2)
image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
images = []
for bidx in range(b):
images.append(image[bidx * f : (bidx + 1) * f])
images = [image[bidx * f : (bidx + 1) * f] for bidx in range(b)]
else:
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]
Comment on lines +148 to +153
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The parallel decoding logic is duplicated in both the is_layered and non-layered branches. This can be refactored to improve maintainability by extracting the decoding step before the conditional post-processing blocks.

images = self.image_processor.postprocess(image[:, :, 0], output_type="pt" if input_info.return_result_tensor else "pil")

if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch_device_module.empty_cache()
Expand Down
Loading