-
Notifications
You must be signed in to change notification settings - Fork 186
update qwen vae #1015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update qwen vae #1015
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |||||||||||||||||||||
| import os | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from lightx2v.utils.envs import * | ||||||||||||||||||||||
| from lightx2v_platform.base.global_var import AI_DEVICE | ||||||||||||||||||||||
|
|
@@ -23,6 +24,7 @@ def __init__(self, config): | |||||||||||||||||||||
| if self.is_layered: | ||||||||||||||||||||||
| self.layers = config.get("layers", 4) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| self.vae_decode_parallel = config.get("vae_decode_parallel", False) | ||||||||||||||||||||||
| self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False)) | ||||||||||||||||||||||
| if self.cpu_offload: | ||||||||||||||||||||||
| self.device = torch.device("cpu") | ||||||||||||||||||||||
|
|
@@ -61,6 +63,69 @@ def _unpack_latents(latents, height, width, vae_scale_factor, layers=None): | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| return latents | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _get_2d_grid(self, total_h, total_w, world_size): | ||||||||||||||||||||||
| best_h, best_w = 1, world_size | ||||||||||||||||||||||
| min_aspect_diff = float("inf") | ||||||||||||||||||||||
| for h in range(1, world_size + 1): | ||||||||||||||||||||||
| if world_size % h == 0: | ||||||||||||||||||||||
| w = world_size // h | ||||||||||||||||||||||
| if total_h % h == 0 and total_w % w == 0: | ||||||||||||||||||||||
| aspect_diff = abs((total_h / h) - (total_w / w)) | ||||||||||||||||||||||
| if aspect_diff < min_aspect_diff: | ||||||||||||||||||||||
| min_aspect_diff = aspect_diff | ||||||||||||||||||||||
| best_h, best_w = h, w | ||||||||||||||||||||||
| return best_h, best_w | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _decode_dist(self, latents): | ||||||||||||||||||||||
| """Parallel 2D spatial decode. latents: (b, c, 1, lh, lw)""" | ||||||||||||||||||||||
| world_size = dist.get_world_size() | ||||||||||||||||||||||
| cur_rank = dist.get_rank() | ||||||||||||||||||||||
| total_h, total_w = latents.shape[3], latents.shape[4] | ||||||||||||||||||||||
| world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size) | ||||||||||||||||||||||
| cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w | ||||||||||||||||||||||
| chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w | ||||||||||||||||||||||
| padding, spatial_ratio = 2, 8 | ||||||||||||||||||||||
|
Comment on lines
+84
to
+87
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Slice with overlap padding | ||||||||||||||||||||||
| if cur_rank_h == 0: | ||||||||||||||||||||||
| h_start, h_end = 0, chunk_h + 2 * padding | ||||||||||||||||||||||
| elif cur_rank_h == world_size_h - 1: | ||||||||||||||||||||||
| h_start, h_end = total_h - chunk_h - 2 * padding, total_h | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| h_start, h_end = cur_rank_h * chunk_h - padding, (cur_rank_h + 1) * chunk_h + padding | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if cur_rank_w == 0: | ||||||||||||||||||||||
| w_start, w_end = 0, chunk_w + 2 * padding | ||||||||||||||||||||||
| elif cur_rank_w == world_size_w - 1: | ||||||||||||||||||||||
| w_start, w_end = total_w - chunk_w - 2 * padding, total_w | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| w_start, w_end = cur_rank_w * chunk_w - padding, (cur_rank_w + 1) * chunk_w + padding | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| chunk = latents[:, :, :, h_start:h_end, w_start:w_end].contiguous() | ||||||||||||||||||||||
| decoded = self.model.decode(chunk, return_dict=False)[0] # (b, c, 1, H', W') | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Trim decoded padding | ||||||||||||||||||||||
| if cur_rank_h == 0: | ||||||||||||||||||||||
| dh_start, dh_end = 0, chunk_h * spatial_ratio | ||||||||||||||||||||||
| elif cur_rank_h == world_size_h - 1: | ||||||||||||||||||||||
| dh_start, dh_end = decoded.shape[3] - chunk_h * spatial_ratio, decoded.shape[3] | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| dh_start, dh_end = padding * spatial_ratio, decoded.shape[3] - padding * spatial_ratio | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if cur_rank_w == 0: | ||||||||||||||||||||||
| dw_start, dw_end = 0, chunk_w * spatial_ratio | ||||||||||||||||||||||
| elif cur_rank_w == world_size_w - 1: | ||||||||||||||||||||||
| dw_start, dw_end = decoded.shape[4] - chunk_w * spatial_ratio, decoded.shape[4] | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| dw_start, dw_end = padding * spatial_ratio, decoded.shape[4] - padding * spatial_ratio | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| piece = decoded[:, :, :, dh_start:dh_end, dw_start:dw_end].contiguous() | ||||||||||||||||||||||
| full = [torch.empty_like(piece) for _ in range(world_size_h * world_size_w)] | ||||||||||||||||||||||
| dist.all_gather(full, piece) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| rows = [torch.cat([full[h_idx * world_size_w + w_idx] for w_idx in range(world_size_w)], dim=4) for h_idx in range(world_size_h)] | ||||||||||||||||||||||
| return torch.cat(rows, dim=3) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||
| def decode(self, latents, input_info): | ||||||||||||||||||||||
| if self.cpu_offload: | ||||||||||||||||||||||
|
|
@@ -74,19 +139,20 @@ def decode(self, latents, input_info): | |||||||||||||||||||||
| latents_mean = torch.tensor(self.vae_latents_mean).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype) | ||||||||||||||||||||||
| latents_std = 1.0 / torch.tensor(self.vae_latents_std).view(1, self.latent_channels, 1, 1, 1).to(latents.device, latents.dtype) | ||||||||||||||||||||||
| latents = latents / latents_std + latents_mean | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| use_vae_decode_parallel = self.vae_decode_parallel and dist.is_initialized() and dist.get_world_size() > 1 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if self.is_layered: | ||||||||||||||||||||||
| b, c, f, h, w = latents.shape | ||||||||||||||||||||||
| latents = latents[:, :, 1:] # remove the first frame as it is the orgin input | ||||||||||||||||||||||
| latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) | ||||||||||||||||||||||
| image = self.model.decode(latents, return_dict=False)[0] # (b f) c 1 h w | ||||||||||||||||||||||
| latents = latents[:, :, 1:].permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) | ||||||||||||||||||||||
| image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0] | ||||||||||||||||||||||
| image = image.squeeze(2) | ||||||||||||||||||||||
| image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil") | ||||||||||||||||||||||
| images = [] | ||||||||||||||||||||||
| for bidx in range(b): | ||||||||||||||||||||||
| images.append(image[bidx * f : (bidx + 1) * f]) | ||||||||||||||||||||||
| images = [image[bidx * f : (bidx + 1) * f] for bidx in range(b)] | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| images = self.model.decode(latents, return_dict=False)[0][:, :, 0] | ||||||||||||||||||||||
| images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil") | ||||||||||||||||||||||
| image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0] | ||||||||||||||||||||||
|
Comment on lines
+148
to
+153
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||
| images = self.image_processor.postprocess(image[:, :, 0], output_type="pt" if input_info.return_result_tensor else "pil") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if self.cpu_offload: | ||||||||||||||||||||||
| self.model.to(torch.device("cpu")) | ||||||||||||||||||||||
| torch_device_module.empty_cache() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_get_2d_gridfunction currently defaults to(1, world_size)if no grid is found that perfectly divides the latent dimensions. If the latent width is not divisible byworld_size, the slicing logic in_decode_distwill result in gaps between spatial chunks, leading to corrupted output images. Additionally, the logic should verify that the resulting chunks are large enough to accommodate the required padding (i.e.,total_h // h >= 2 * padding).