Conversation
There was a problem hiding this comment.
Code Review
This pull request implements parallel 2D spatial decoding for the VAE model using torch.distributed. It introduces a vae_decode_parallel configuration option and adds logic to split latents into spatial chunks with overlapping padding, decode them in parallel, and gather the results. Review feedback suggests improving the grid calculation to ensure dimensions are perfectly divisible and large enough for padding, handling cases where a valid grid cannot be determined, and refactoring duplicated decoding logic across the layered and non-layered code paths.
| def _get_2d_grid(self, total_h, total_w, world_size): | ||
| best_h, best_w = 1, world_size | ||
| min_aspect_diff = float("inf") | ||
| for h in range(1, world_size + 1): | ||
| if world_size % h == 0: | ||
| w = world_size // h | ||
| if total_h % h == 0 and total_w % w == 0: | ||
| aspect_diff = abs((total_h / h) - (total_w / w)) | ||
| if aspect_diff < min_aspect_diff: | ||
| min_aspect_diff = aspect_diff | ||
| best_h, best_w = h, w | ||
| return best_h, best_w |
There was a problem hiding this comment.
The _get_2d_grid function currently defaults to (1, world_size) if no grid is found that perfectly divides the latent dimensions. If the latent width is not divisible by world_size, the slicing logic in _decode_dist will result in gaps between spatial chunks, leading to corrupted output images. Additionally, the logic should verify that the resulting chunks are large enough to accommodate the required padding (i.e., total_h // h >= 2 * padding).
| def _get_2d_grid(self, total_h, total_w, world_size): | |
| best_h, best_w = 1, world_size | |
| min_aspect_diff = float("inf") | |
| for h in range(1, world_size + 1): | |
| if world_size % h == 0: | |
| w = world_size // h | |
| if total_h % h == 0 and total_w % w == 0: | |
| aspect_diff = abs((total_h / h) - (total_w / w)) | |
| if aspect_diff < min_aspect_diff: | |
| min_aspect_diff = aspect_diff | |
| best_h, best_w = h, w | |
| return best_h, best_w | |
| def _get_2d_grid(self, total_h, total_w, world_size, padding=2): | |
| best_h, best_w = None, None | |
| min_aspect_diff = float("inf") | |
| for h in range(1, world_size + 1): | |
| if world_size % h == 0: | |
| w = world_size // h | |
| if total_h % h == 0 and total_w % w == 0: | |
| if total_h // h >= 2 * padding and total_w // w >= 2 * padding: | |
| aspect_diff = abs((total_h / h) - (total_w / w)) | |
| if aspect_diff < min_aspect_diff: | |
| min_aspect_diff = aspect_diff | |
| best_h, best_w = h, w | |
| return best_h, best_w |
| world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size) | ||
| cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w | ||
| chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w | ||
| padding, spatial_ratio = 2, 8 |
There was a problem hiding this comment.
The _decode_dist method should handle cases where _get_2d_grid fails to find a valid split (returning None). Also, spatial_ratio is hardcoded to 8, which should ideally be retrieved from the configuration for consistency across different model versions.
| world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size) | |
| cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w | |
| chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w | |
| padding, spatial_ratio = 2, 8 | |
| padding, spatial_ratio = 2, self.config.get("vae_scale_factor", 8) | |
| world_size_h, world_size_w = self._get_2d_grid(total_h, total_w, world_size, padding) | |
| if world_size_h is None: | |
| return self.model.decode(latents, return_dict=False)[0] | |
| cur_rank_h, cur_rank_w = cur_rank // world_size_w, cur_rank % world_size_w | |
| chunk_h, chunk_w = total_h // world_size_h, total_w // world_size_w |
| image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0] | ||
| image = image.squeeze(2) | ||
| image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil") | ||
| images = [] | ||
| for bidx in range(b): | ||
| images.append(image[bidx * f : (bidx + 1) * f]) | ||
| images = [image[bidx * f : (bidx + 1) * f] for bidx in range(b)] | ||
| else: | ||
| images = self.model.decode(latents, return_dict=False)[0][:, :, 0] | ||
| images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil") | ||
| image = self._decode_dist(latents) if use_vae_decode_parallel else self.model.decode(latents, return_dict=False)[0] |
No description provided.