update qwen vae#1015

Merged
llmc-reviewer merged 1 commit into main from vae
Apr 15, 2026
Conversation

@helloyongyang
Contributor

No description provided.

@llmc-reviewer llmc-reviewer merged commit 1f8b305 into main Apr 15, 2026
2 checks passed
@llmc-reviewer llmc-reviewer deleted the vae branch April 15, 2026 09:13
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements parallel 2D spatial decoding for the VAE model using torch.distributed. It introduces a vae_decode_parallel configuration option and adds logic to split latents into spatial chunks with overlapping padding, decode them in parallel, and gather the results. Review feedback suggests improving the grid calculation to ensure dimensions are perfectly divisible and large enough for padding, handling cases where a valid grid cannot be determined, and refactoring duplicated decoding logic across the layered and non-layered code paths.
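To make the chunking scheme concrete, here is a minimal sketch in plain Python of how one rank might derive its padded latent slice and the region to keep after decoding. The helper name `tile_bounds` and its return layout are hypothetical. The row-major rank mapping and `padding = 2` follow the snippets quoted in this review, and the sketch assumes the latent dimensions divide evenly by the grid; it is not the merged implementation.

```python
def tile_bounds(rank, grid_h, grid_w, total_h, total_w, padding=2):
    """Return (padded_slice, keep_slice) for one rank, in latent coordinates.

    Hypothetical helper: assumes total_h % grid_h == 0 and total_w % grid_w == 0.
    """
    row, col = rank // grid_w, rank % grid_w  # row-major rank layout
    chunk_h, chunk_w = total_h // grid_h, total_w // grid_w

    # Core region this rank is responsible for.
    h0, h1 = row * chunk_h, (row + 1) * chunk_h
    w0, w1 = col * chunk_w, (col + 1) * chunk_w

    # Extend by `padding` on every side, clamped to the latent borders.
    ph0, ph1 = max(h0 - padding, 0), min(h1 + padding, total_h)
    pw0, pw1 = max(w0 - padding, 0), min(w1 + padding, total_w)

    # Offsets of the core region inside the padded slice. After decoding,
    # these would be scaled by the spatial ratio to crop in pixel units.
    keep = (h0 - ph0, h1 - ph0, w0 - pw0, w1 - pw0)
    return (ph0, ph1, pw0, pw1), keep


# Example: 4 ranks in a 2x2 grid over a 64x64 latent.
for rank in range(4):
    print(rank, tile_bounds(rank, 2, 2, 64, 64))
```

Each rank decodes its padded slice, crops the decoded tile back to the core region, and the cropped tiles are then gathered and stitched in grid order.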

Comment on lines +66 to +77
def _get_2d_grid(self, total_h, total_w, world_size):
    best_h, best_w = 1, world_size
    min_aspect_diff = float("inf")
    for h in range(1, world_size + 1):
        if world_size % h == 0:
            w = world_size // h
            if total_h % h == 0 and total_w % w == 0:
                aspect_diff = abs((total_h / h) - (total_w / w))
                if aspect_diff < min_aspect_diff:
                    min_aspect_diff = aspect_diff
                    best_h, best_w = h, w
    return best_h, best_w

high

The _get_2d_grid function currently defaults to (1, world_size) if no grid is found that perfectly divides the latent dimensions. If the latent width is not divisible by world_size, the slicing logic in _decode_dist will result in gaps between spatial chunks, leading to corrupted output images. Additionally, the logic should verify that the resulting chunks are large enough to accommodate the required padding (i.e., total_h // h >= 2 * padding).
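The fallback problem can be reproduced with a standalone copy of the quoted function (with `self` dropped). The sizes below are illustrative, not taken from the PR, and the coverage arithmetic assumes each rank takes `total_w // world_size_w` columns, as in the quoted `_decode_dist` lines.

```python
def get_2d_grid(total_h, total_w, world_size):
    """Standalone copy of the grid search under review (no `self`)."""
    best_h, best_w = 1, world_size
    min_aspect_diff = float("inf")
    for h in range(1, world_size + 1):
        if world_size % h == 0:
            w = world_size // h
            if total_h % h == 0 and total_w % w == 0:
                aspect_diff = abs((total_h / h) - (total_w / w))
                if aspect_diff < min_aspect_diff:
                    min_aspect_diff = aspect_diff
                    best_h, best_w = h, w
    return best_h, best_w


# Illustrative sizes: a 45x45 latent on 4 ranks has no exact split,
# so the function falls back to its default of (1, world_size).
total_h, total_w, world_size = 45, 45, 4
grid_h, grid_w = get_2d_grid(total_h, total_w, world_size)
chunk_w = total_w // grid_w
uncovered_cols = total_w - grid_w * chunk_w
print(grid_h, grid_w, chunk_w, uncovered_cols)  # prints: 1 4 11 1
```

With four chunks of 11 columns each, one of the 45 latent columns is never assigned to any rank.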

Suggested change
-def _get_2d_grid(self, total_h, total_w, world_size):
-    best_h, best_w = 1, world_size
-    min_aspect_diff = float("inf")
-    for h in range(1, world_size + 1):
-        if world_size % h == 0:
-            w = world_size // h
-            if total_h % h == 0 and total_w % w == 0:
-                aspect_diff = abs((total_h / h) - (total_w / w))
-                if aspect_diff < min_aspect_diff:
-                    min_aspect_diff = aspect_diff
-                    best_h, best_w = h, w
-    return best_h, best_w
+def _get_2d_grid(self, total_h, total_w, world_size, padding=2):
+    best_h, best_w = None, None
+    min_aspect_diff = float("inf")
+    for h in range(1, world_size + 1):
+        if world_size % h == 0:
+            w = world_size // h
+            if total_h % h == 0 and total_w % w == 0:
+                if total_h // h >= 2 * padding and total_w // w >= 2 * padding:
+                    aspect_diff = abs((total_h / h) - (total_w / w))
+                    if aspect_diff < min_aspect_diff:
+                        min_aspect_diff = aspect_diff
+                        best_h, best_w = h, w
+    return best_h, best_w
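As a quick check of the suggested version, the sketch below runs it as a standalone function (no `self`) on two illustrative latent sizes that are not taken from the PR. It shows the case where a grid divides the latent exactly but the chunks are too small for the padding.

```python
def get_2d_grid(total_h, total_w, world_size, padding=2):
    """Standalone copy of the suggested grid search (no `self`)."""
    best_h, best_w = None, None
    min_aspect_diff = float("inf")
    for h in range(1, world_size + 1):
        if world_size % h == 0:
            w = world_size // h
            if total_h % h == 0 and total_w % w == 0:
                # Reject grids whose chunks cannot hold padding on both sides.
                if total_h // h >= 2 * padding and total_w // w >= 2 * padding:
                    aspect_diff = abs((total_h / h) - (total_w / w))
                    if aspect_diff < min_aspect_diff:
                        min_aspect_diff = aspect_diff
                        best_h, best_w = h, w
    return best_h, best_w


# A 2x2 grid divides a 6x6 latent, but 3x3 chunks are below 2 * padding = 4.
print(get_2d_grid(6, 6, 4))    # prints: (None, None)
# A 64x64 latent splits cleanly into four 32x32 chunks.
print(get_2d_grid(64, 64, 4))  # prints: (2, 2)
```

Because the function can now return `(None, None)`, any caller has to check for `None` before using the result in rank or chunk arithmetic.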

Comment on lines +84 to +87
world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size)
cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w
chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w
padding, spatial_ratio = 2, 8

medium

The _decode_dist method should handle cases where _get_2d_grid fails to find a valid split (returning None). Also, spatial_ratio is hardcoded to 8, which should ideally be retrieved from the configuration for consistency across different model versions.

Suggested change
-world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size)
-cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w
-chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w
-padding, spatial_ratio = 2, 8
+padding, spatial_ratio = 2, self.config.get("vae_scale_factor", 8)
+world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size, padding)
+if world_size_h is None:
+    return self.model.decode(latents, return_dict=False)[0]
+cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w
+chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w

Comment on lines +148 to +153
+    image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]
     image = image.squeeze(2)
     image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
-    images = []
-    for bidx in range(b):
-        images.append(image[bidx * f : (bidx + 1) * f])
+    images = [image[bidx * f : (bidx + 1) * f] for bidx in range(b)]
 else:
-    images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
-    images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
+    image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0]

medium

The parallel decoding logic is duplicated in both the is_layered and non-layered branches. This can be refactored to improve maintainability by extracting the decoding step before the conditional post-processing blocks.
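One way to read this suggestion is to choose the decode path once and keep only the post-processing inside the branches. The sketch below uses stub callables so it runs without torch. `decode_latents` and `run_decode_step` are hypothetical names, and the branch bodies are placeholders rather than the real squeeze and post-processing steps.

```python
def decode_latents(latents, use_parallel, decode_dist, decode_full):
    """Pick the decode path once, before any layered/non-layered branching.

    `decode_dist` and `decode_full` are stand-ins for `self._decode_dist`
    and `lambda x: self.model.decode(x, return_dict=False)[0]`.
    """
    return decode_dist(latents) if use_parallel else decode_full(latents)


def run_decode_step(latents, is_layered, use_parallel, decode_dist, decode_full):
    # Single decode call shared by both branches.
    image = decode_latents(latents, use_parallel, decode_dist, decode_full)
    if is_layered:
        return ("layered", image)  # placeholder for squeeze + per-batch split
    return ("single", image)       # placeholder for the non-layered path


# Stub decoders that only tag their input.
dist = lambda x: f"dist({x})"
full = lambda x: f"full({x})"
print(run_decode_step("z", True, True, dist, full))    # prints: ('layered', 'dist(z)')
print(run_decode_step("z", False, False, dist, full))  # prints: ('single', 'full(z)')
```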
