diff --git a/lightx2v/models/video_encoders/hf/qwen_image/vae.py b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
index 13a64bf35..db76b1740 100755
--- a/lightx2v/models/video_encoders/hf/qwen_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
@@ -2,6 +2,7 @@
 import os
 
 import torch
+import torch.distributed as dist
 
 from lightx2v.utils.envs import *
 from lightx2v_platform.base.global_var import AI_DEVICE
@@ -23,6 +24,7 @@ def __init__(self, config):
 
         if self.is_layered:
             self.layers = config.get("layers", 4)
+        self.vae_decode_parallel = config.get("vae_decode_parallel", False)
         self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
         if self.cpu_offload:
            self.device = torch.device("cpu")
@@ -61,6 +63,69 @@ def _unpack_latents(latents, height, width, vae_scale_factor, layers=None):
 
         return latents
 
+    def _get_2d_grid(self, total_h, total_w, world_size):
+        best_h, best_w = 1, world_size
+        min_aspect_diff = float("inf")
+        for h in range(1, world_size + 1):
+            if world_size % h == 0:
+                w = world_size // h
+                if total_h % h == 0 and total_w % w == 0:
+                    aspect_diff = abs((total_h / h) - (total_w / w))
+                    if aspect_diff < min_aspect_diff:
+                        min_aspect_diff = aspect_diff
+                        best_h, best_w = h, w
+        return best_h, best_w
+
+    def _decode_dist(self, latents):
+        """Parallel 2D spatial decode. latents: (b, c, 1, lh, lw)"""
+        world_size = dist.get_world_size()
+        cur_rank = dist.get_rank()
+        total_h, total_w = latents.shape[3], latents.shape[4]
+        world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size)
+        cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w
+        chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w
+        padding, spatial_ratio = 2, 8
+
+        # Slice with overlap padding
+        if cur_rank_h == 0:
+            h_start, h_end = 0, chunk_h + 2 * padding
+        elif cur_rank_h == world_size_h - 1:
+            h_start, h_end = total_h - chunk_h - 2 * padding, total_h
+        else:
+            h_start, h_end = cur_rank_h * chunk_h - padding, (cur_rank_h + 1) * chunk_h + padding
+
+        if cur_rank_w == 0:
+            w_start, w_end = 0, chunk_w + 2 * padding
+        elif cur_rank_w == world_size_w - 1:
+            w_start, w_end = total_w - chunk_w - 2 * padding, total_w
+        else:
+            w_start, w_end = cur_rank_w * chunk_w - padding, (cur_rank_w + 1) * chunk_w + padding
+
+        chunk = latents[:, :, :, h_start:h_end, w_start:w_end].contiguous()
+        decoded = self.model.decode(chunk, return_dict=False)[0]  # (b, c, 1, H', W')
+
+        # Trim decoded padding
+        if cur_rank_h == 0:
+            dh_start, dh_end = 0, chunk_h * spatial_ratio
+        elif cur_rank_h == world_size_h - 1:
+            dh_start, dh_end = decoded.shape[3] - chunk_h * spatial_ratio, decoded.shape[3]
+        else:
+            dh_start, dh_end = padding * spatial_ratio, decoded.shape[3] - padding * spatial_ratio
+
+        if cur_rank_w == 0:
+            dw_start, dw_end = 0, chunk_w * spatial_ratio
+        elif cur_rank_w == world_size_w - 1:
+            dw_start, dw_end = decoded.shape[4] - chunk_w * spatial_ratio, decoded.shape[4]
+        else:
+            dw_start, dw_end = padding * spatial_ratio, decoded.shape[4] - padding * spatial_ratio
+
+        piece = decoded[:, :, :, dh_start:dh_end, dw_start:dw_end].contiguous()
+        full = [torch.empty_like(piece) for _ in range(world_size_h * world_size_w)]
+        dist.all_gather(full, piece)
+
+        rows = [torch.cat([full[h_idx * world_size_w + w_idx] for w_idx in range(world_size_w)], dim=4) for h_idx in range(world_size_h)]
+        return torch.cat(rows, dim=3)
+
     @torch.no_grad()
     def decode(self, latents, input_info):
         if self.cpu_offload:
@@ -74,19 +139,20 @@ def decode(self, latents, input_info):
         latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype)
         latents = latents / latents_std + latents_mean
+
+        use_vae_decode_parallel = self.vae_decode_parallel and dist.is_initialized() and dist.get_world_size() > 1
+
         if self.is_layered:
             b, c, f, h, w = latents.shape
-            latents = latents[:, :, 1:]  # remove the first frame as it is the orgin input
-            latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
-            image = self.model.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
+            latents = latents[:, :, 1:].permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
+            image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]
             image = image.squeeze(2)
             image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
-            images = []
-            for bidx in range(b):
-                images.append(image[bidx * f : (bidx + 1) * f])
+            images = [image[bidx * f : (bidx + 1) * f] for bidx in range(b)]
         else:
-            images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
-            images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
+            image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]
+            images = self.image_processor.postprocess(image[:, :, 0], output_type="pt" if input_info.return_result_tensor else "pil")
+
         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
             torch_device_module.empty_cache()