fix: load frozen autoencoder checkpoint on CPU to avoid DDP hangs #10
Open
DannyP0 wants to merge 1 commit into
Summary
Load the frozen autoencoder via `AutoEncoder.load_from_checkpoint(..., map_location=torch.device("cpu"))` in `Proteina.load_autoencoder`, so weights are not restored onto CUDA before Lightning assigns each process to its rank-local GPU.
Problem
When training with multi-GPU DDP (e.g. `CUDA_VISIBLE_DEVICES=…` and `Trainer(strategy="ddp", devices>1)`), `Proteina` is constructed, and its `__init__` runs, before `torch.distributed` initialization and per-rank device binding are complete for all code paths.
Without `map_location`, checkpoint tensors are often placed on the default CUDA device (typically logical `cuda:0`). For every spawned training process that still sees the same visible device order, that is the same physical GPU. As a result:
- Rank 0 and Rank 1 can both allocate a full VAE on one GPU.
- Rank > 0 may later run the trainable stack on another GPU while still holding VAE state on the first, producing inconsistent device layouts.
This showed up as NCCL / training hangs and confusing nvidia-smi output (same PID on multiple GPUs).
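The device collision described above can be illustrated without any GPU. The sketch below is only a model of the mapping, not code from this PR: `physical_gpu` is a hypothetical helper, and it assumes `CUDA_VISIBLE_DEVICES` is a comma-separated list of integer indices (it may also contain UUIDs in practice) that every rank inherits unchanged.

```python
def physical_gpu(cuda_visible_devices: str, logical_index: int) -> str:
    """Return the physical GPU id behind logical device cuda:<logical_index>.

    Hypothetical helper for illustration; assumes a comma-separated id list.
    """
    visible = [d.strip() for d in cuda_visible_devices.split(",") if d.strip()]
    return visible[logical_index]


# Hypothetical two-rank job; both processes inherit the same visible order.
visible = "2,3"

# GPU each rank is supposed to train on (logical index == local rank).
intended = {rank: physical_gpu(visible, rank) for rank in (0, 1)}

# GPU hit by a checkpoint load that lands on logical cuda:0 in every process.
vae_load = {rank: physical_gpu(visible, 0) for rank in (0, 1)}

print("intended:", intended)
print("vae load:", vae_load)
```

In this model both ranks resolve `cuda:0` to the same physical GPU, while rank 1 is meant to train on a different one, which matches the layout mismatch reported above.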
How to reproduce (pre-fix)
The hang / incorrect multi-GPU memory layout showed up in downstream training that loads a frozen autoencoder inside `Proteina` (e.g. unconditional local-latents diffusion training with `autoencoder_ckpt_path` set), using multi-GPU DDP (`Trainer` with `strategy="ddp"` and `devices>1`).
Note: We did not see this issue when only running the first-stage VAE / autoencoder training with DDP. That path does not exercise the same `AutoEncoder.load_from_checkpoint` call inside `Proteina` during module construction, so reviewers should reproduce using the second-stage (latent-space) training entrypoint and config that enables the frozen VAE.
Minimal hints for maintainers:
- Train `Proteina` with a non-null `autoencoder_ckpt_path`.
- Set `CUDA_VISIBLE_DEVICES` to a subset of GPUs.
- Before this PR, check `nvidia-smi` for the symptom (e.g. one process using multiple GPUs). After this PR, training should start normally.
Solution
Deserialize the VAE checkpoint to CPU only. Lightning then moves the full LightningModule (including the frozen autoencoder submodule) to the correct device for each rank during normal setup. Training-time VAE inference remains on GPU; only the initial load path changes.
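As a rough sketch of the CPU-staging idea, the snippet below uses plain `torch.save` / `torch.load` with a small `nn.Linear` as a stand-in. It is not the `AutoEncoder.load_from_checkpoint` call changed in this PR: the stand-in module, the in-memory buffer, and the variable names are all assumptions made so the example runs on a CPU-only machine.

```python
import io

import torch
from torch import nn

# Stand-in for the frozen VAE; the real class (AutoEncoder) is not reproduced here.
vae = nn.Linear(4, 2)

# Serialize a state dict to an in-memory buffer (stands in for the checkpoint file).
buffer = io.BytesIO()
torch.save(vae.state_dict(), buffer)
buffer.seek(0)

# Deserialize onto CPU regardless of the device the tensors were saved from.
state = torch.load(buffer, map_location=torch.device("cpu"))

# Rebuild the module on CPU and freeze it: no gradients, eval mode.
frozen = nn.Linear(4, 2)
frozen.load_state_dict(state)
frozen.requires_grad_(False).eval()

devices = {p.device.type for p in frozen.parameters()}
print(devices)
```

Once the frozen module is a submodule of the `LightningModule`, the later device move described above carries it to each rank's GPU together with the trainable parameters.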
Testing
Notes
Slight one-time overhead: the checkpoint is staged on CPU and then copied to GPU when the module is moved to the training device. This is negligible relative to full training.
Alternatives considered
We also tried loading the checkpoint directly onto the training device with `map_location=torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}")`. In our setup this ran correctly and we did not observe a meaningful training throughput difference versus `map_location="cpu"` (the extra cost is dominated by normal training).
We still prefer `map_location=torch.device("cpu")` as the default fix because it avoids relying on `LOCAL_RANK` being set in every code path before `Proteina.__init__` (the rank-0 entry process may not have distributed env vars set yet when it first constructs the module, depending on Lightning launcher behavior). CPU staging keeps behavior consistent across those paths without special-casing rank 0 vs workers.
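To make the trade-off concrete, here is a small sketch comparing the two `map_location` choices as plain strings. The function names and the example environments are hypothetical; the first function only mirrors the `LOCAL_RANK` expression quoted above and takes the environment as an argument so it can be exercised without launching DDP.

```python
def local_rank_map_location(env) -> str:
    """Alternative from the text: cuda:<LOCAL_RANK>, defaulting to 0 if unset."""
    return f"cuda:{int(env.get('LOCAL_RANK', '0'))}"


def cpu_map_location(env) -> str:
    """Preferred fix: always stage on CPU, independent of the environment."""
    return "cpu"


# Hypothetical environments: a worker with LOCAL_RANK exported, and an entry
# process that constructs the module before the launcher has set it.
worker_env = {"LOCAL_RANK": "1"}
entry_env = {}

for name, env in (("worker", worker_env), ("entry", entry_env)):
    print(name, local_rank_map_location(env), cpu_map_location(env))
```

The rank-based variant depends on what the environment holds at construction time and silently falls back to `cuda:0` when `LOCAL_RANK` is missing, whereas the CPU variant returns the same answer in every code path.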