Skip to content

fix: load frozen autoencoder checkpoint on CPU to avoid DDP hangs#10

Open
DannyP0 wants to merge 1 commit into
NVIDIA-BioNeMo:mainfrom
DannyP0:main
Open

fix: load frozen autoencoder checkpoint on CPU to avoid DDP hangs#10
DannyP0 wants to merge 1 commit into
NVIDIA-BioNeMo:mainfrom
DannyP0:main

Conversation

@DannyP0
Copy link
Copy Markdown

@DannyP0 DannyP0 commented Apr 2, 2026

Summary

Load the frozen autoencoder via AutoEncoder.load_from_checkpoint(..., map_location=torch.device("cpu")) in Proteina.load_autoencoder, so weights are not restored onto CUDA before Lightning assigns each process to its rank-local GPU.

Problem

When training with multi-GPU DDP (e.g. CUDA_VISIBLE_DEVICES=… and Trainer(strategy="ddp", devices>1)), Proteina is constructed inside __init__ before torch.distributed / per-rank device binding is complete for all code paths.

Without map_location, checkpoint tensors are often placed on the default CUDA device (typically logical cuda:0). For every spawned training process that still sees the same visible device order, that is the same physical GPU. As a result:

Rank 0 and Rank 1 can both allocate a full VAE on one GPU.
Rank > 0 may later run the trainable stack on another GPU while still holding VAE state on the first, producing inconsistent device layouts.

This showed up as NCCL / training hangs and confusing nvidia-smi output (same PID on multiple GPUs).

How to reproduce (pre-fix)

The hang / incorrect multi-GPU memory layout showed up in downstream training that loads a frozen autoencoder inside Proteina (e.g. unconditional local-latents diffusion training with autoencoder_ckpt_path set), using multi-GPU DDP (Trainer with strategy="ddp" and devices>1).

Note: We did not see this issue when only running the first-stage VAE / autoencoder training with DDP. That path does not exercise the same AutoEncoder.load_from_checkpoint call inside Proteina during module construction, so reviewers should reproduce using the second-stage (latent-space) training entrypoint and config that enables the frozen VAE.

Minimal hints for maintainers:

  • Use the training script / Hydra config that builds Proteina with a non-null autoencoder_ckpt_path.
  • Run with two (or more) GPUs and DDP; optionally set CUDA_VISIBLE_DEVICES to a subset of GPUs.
  • Before this PR, observe distributed hang or abnormal nvidia-smi (e.g. one process using multiple GPUs). After this PR, training should start normally.

Solution

Deserialize the VAE checkpoint to CPU only. Lightning then moves the full LightningModule (including the frozen autoencoder submodule) to the correct device for each rank during normal setup. Training-time VAE inference remains on GPU; only the initial load path changes.

Testing

  • Single-GPU training (sanity).
  • Multi-GPU DDP (e.g. 2 GPUs): training starts without hang; one process per GPU with expected memory footprint.

Notes

Slight one-time overhead: CPU staging then copy to GPU when the module is moved to the training device—negligible relative to full training.

Alternatives considered

We also tried loading the checkpoint directly onto the training device with
map_location=torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}")
In our setup this ran correctly and we did not observe a meaningful training throughput difference versus map_location="cpu" (the extra cost is dominated by normal training).

We still prefer map_location=torch.device("cpu") as the default fix because it avoids relying on LOCAL_RANK being set in every code path before Proteina.__init__ (the rank-0 entry process may not have distributed env vars set yet when it first constructs the module, depending on Lightning launcher behavior). CPU staging keeps behavior consistent across those paths without special-casing rank 0 vs workers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant