A Variational Autoencoder (VAE) training pipeline for frequency map reconstruction, designed for NERSC Perlmutter. See BeamVAE.md for the coordinate convention and physics motivation.
# Clone the repository
git clone https://github.com/ndwang/beam_vae.git $PSCRATCH/vae
cd $PSCRATCH/vae
# Load conda and create environment
ml load conda
conda env create -f environment.yml
conda activate vae
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
# Train with default settings
python scripts/train.py
# Train with custom latent dimension
python scripts/train.py model.latent_dim=128
# Train residual VAE with custom beta
python scripts/train.py model=model/residual_vae2d.yaml training.beta=1e-5
The pipeline is configured with YAML files: configs/default.yaml composes the model, training, and data sub-configs, and any value can be overridden from the command line.
configs/
├── default.yaml # Main config (references sub-configs)
├── model/
│ ├── vae2d.yaml # Standard VAE
│ └── residual_vae2d.yaml # Residual VAE
├── training/
│ └── default.yaml # Training hyperparameters
└── data/
├── frequency_maps.yaml # Min-max normalized data
└── frequency_maps_log.yaml # Log-transformed data
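As a rough illustration of how dot-notation overrides can map onto a nested YAML config, the sketch below walks each dotted key into a nested dict and converts the value string to a Python type. It is a hypothetical stand-in, not the code in src/utils/config.py: the names apply_overrides and parse_value are invented here, and switching a whole sub-config (model=model/residual_vae2d.yaml) is not handled.

```python
import ast
import copy


def parse_value(text: str):
    """Convert a CLI string to bool/None/int/float/list where possible."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none"):
        return None
    try:
        # Handles ints, floats such as 1e-5, and lists such as [32, 64].
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Anything else (e.g. relu, my-vae-experiments) stays a plain string.
        return text


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of config with each 'a.b.c=value' override applied."""
    result = copy.deepcopy(config)  # leave the loaded base config untouched
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return result


# Example: apply_overrides(cfg, ["model.latent_dim=32", "training.wandb.enabled=true"])
```

Keeping the parsing permissive (unparseable values stay strings) lets values such as relu or my-vae-experiments pass through without extra quoting on the command line.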
Model (configs/model/vae2d.yaml):
| Parameter | Default | Description |
|---|---|---|
| input_channels | 15 | Number of input channels |
| hidden_channels | [32, 64, 128, 256, 512] | Encoder/decoder channel sizes |
| latent_dim | 64 | Latent space dimension |
| input_size | 64 | Spatial size (64x64) |
| activation | relu | Activation function |
| dropout_rate | 0.0 | Dropout probability |
Training (configs/training/default.yaml):
| Parameter | Default | Description |
|---|---|---|
| epochs | 300 | Number of training epochs |
| batch_size | 256 | Batch size |
| lr | 5e-4 | Learning rate |
| weight_decay | 1e-4 | AdamW weight decay |
| beta | 0.0 | KL divergence weight (β-VAE) |
| val_split | 0.1 | Validation split ratio |
| seed | 42 | Random seed |
| checkpoint_freq | 50 | Save checkpoint every N epochs |
| wandb.enabled | false | Enable Weights & Biases logging |
| wandb.project | beam-vae | W&B project name |
| wandb.offline | true | Offline mode (for NERSC) |
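The beta parameter weights the KL term, so the logged total loss is recon + β×KL. A minimal pure-Python sketch for a single sample follows, assuming a diagonal-Gaussian posterior, a standard-normal prior, the KL summed over latent dimensions, and mean squared error as a stand-in reconstruction loss; how src/training/losses.py actually reduces over pixels, channels, and the batch is not stated here.

```python
import math


def kl_divergence(mu: list[float], logvar: list[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def mse(x: list[float], x_hat: list[float]) -> float:
    """Mean squared reconstruction error (assumed stand-in for the repo's loss)."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def total_loss(x, x_hat, mu, logvar, beta: float) -> tuple[float, float, float]:
    """Return (total, recon, kl) with total = recon + beta * kl."""
    recon = mse(x, x_hat)
    kl = kl_divergence(mu, logvar)
    return recon + beta * kl, recon, kl
```

With the default beta of 0.0, this sketch still computes the KL term (so it can be logged) but it contributes nothing to the total.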
Override any config value using dot notation:
# Single override
python scripts/train.py model.latent_dim=32
# Multiple overrides
python scripts/train.py model.latent_dim=32 training.lr=1e-4 training.epochs=500
# Switch sub-config
python scripts/train.py model=model/residual_vae2d.yaml
# Custom run name
python scripts/train.py run_name=my_experiment
python scripts/train.py
This will:
- Load config from configs/default.yaml
- Create dataset from configured path
- Train the model
- Save outputs to runs/<run_name>/
Each training run creates a directory with:
runs/<run_name>/
├── config.yaml # Full configuration (for reproducibility)
├── <run_name>.pth # Final model weights
├── <run_name>_best.pth # Best model checkpoint (lowest val loss)
├── <run_name>_epoch{N}.pth # Periodic checkpoints
├── <run_name>_history.csv # Training/validation loss history
└── wandb/ # W&B logs (if enabled)
History CSV Contents (<run_name>_history.csv):
| Column | Description |
|---|---|
| epoch | Epoch number |
| train_total | Training total loss (recon + β×KL) |
| train_recon | Training reconstruction loss |
| train_kl | Training KL divergence |
| val_total | Validation total loss |
| val_recon | Validation reconstruction loss |
| val_kl | Validation KL divergence |
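To find which epoch produced the lowest validation loss (the one the _best.pth checkpoint should correspond to), the history file can be read with the standard csv module. A sketch, assuming the file has a header row with the column names above; best_epoch is a hypothetical helper:

```python
import csv
from typing import Iterable


def best_epoch(lines: Iterable[str]) -> tuple[int, float]:
    """Return (epoch, val_total) for the row with the lowest validation total loss."""
    rows = list(csv.DictReader(lines))
    if not rows:
        raise ValueError("history file has no data rows")
    best = min(rows, key=lambda row: float(row["val_total"]))
    return int(float(best["epoch"])), float(best["val_total"])


# Usage:
# with open("runs/<run_name>/<run_name>_history.csv", newline="") as f:
#     print(best_epoch(f))
```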
Run Naming: Auto-generated run names use a concise format:
# Format: latent{dim}_beta{beta}_{YYMMDD}_{HHMM}
latent64_beta1e-05_260126_1430
The minute-resolution timestamp distinguishes runs launched at different times while keeping names readable.
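The naming pattern can be reproduced with an f-string. This sketch assumes beta is rendered with Python's default float formatting, which turns 1e-5 into 1e-05 as in the example above; make_run_name is a hypothetical helper, not necessarily how scripts/train.py builds the name.

```python
from datetime import datetime
from typing import Optional


def make_run_name(latent_dim: int, beta: float, now: Optional[datetime] = None) -> str:
    """Build 'latent{dim}_beta{beta}_{YYMMDD}_{HHMM}'."""
    now = now or datetime.now()
    # datetime supports strftime codes directly inside f-string format specs.
    return f"latent{latent_dim}_beta{beta}_{now:%y%m%d_%H%M}"
```

Because the timestamp only has minute resolution, two runs with the same latent_dim and beta started in the same minute would get the same auto-generated name; passing run_name explicitly avoids that.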
Checkpoint Contents:
| Key | Description |
|---|---|
| epoch | Training epoch number |
| model_state_dict | Model weights |
| optimizer_state_dict | Optimizer state |
| scheduler_state_dict | Scheduler state |
| train_loss | Training total loss |
| train_recon_loss | Training reconstruction loss |
| train_kl_loss | Training KL divergence |
| val_loss | Validation total loss |
| val_recon_loss | Validation reconstruction loss |
| val_kl_loss | Validation KL divergence |
| beta | KL divergence weight (β-VAE) |
Resume training from a checkpoint, either after an interruption or to train a finished run for more epochs:
# Resume from best checkpoint
python scripts/train.py --resume runs/my_run/my_run_best.pth
# Resume from specific epoch checkpoint
python scripts/train.py --resume runs/my_run/my_run_epoch100.pth
# Resume with modified config (e.g., more epochs)
python scripts/train.py --resume runs/my_run/my_run_best.pth training.epochs=500
The resume functionality:
- Restores model weights, optimizer state, and scheduler state
- Continues from the saved epoch number
- Preserves the best validation loss for checkpointing
- Warns if beta differs between checkpoint and current config
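The resume bookkeeping can be sketched in plain Python. Here the checkpoint argument stands in for the dict that torch.load(path, map_location="cpu") would return for one of the .pth files, with the keys from the checkpoint table. Two assumptions are worth flagging: that epoch stores the last completed epoch (so training restarts at epoch + 1), and that val_loss seeds the best-loss tracker; the trainer's actual choices may differ. resume_state is a hypothetical name.

```python
import warnings


def resume_state(checkpoint: dict, config_beta: float) -> tuple[int, float]:
    """Return (start_epoch, best_val_loss) from a checkpoint dict.

    A real trainer would also restore model, optimizer, and scheduler via
    load_state_dict() using the matching *_state_dict entries.
    """
    if checkpoint["beta"] != config_beta:
        warnings.warn(
            f"beta differs: checkpoint={checkpoint['beta']}, config={config_beta}",
            stacklevel=2,
        )
    start_epoch = checkpoint["epoch"] + 1  # assumption: 'epoch' = last completed epoch
    best_val = checkpoint["val_loss"]  # assumption: seed best loss from the checkpoint
    return start_epoch, best_val
```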
Training prints per-epoch metrics:
Epoch 1/300 | Train: 0.012345 | Val: 0.013456 | LR: 5.00e-04
Epoch 2/300 | Train: 0.011234 | Val: 0.012345 | LR: 5.00e-04
...
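When W&B is disabled, these lines in the SLURM output under logs/ are the quickest record of progress. A small regex parser, assuming the exact field layout shown above (parse_epoch_line is a hypothetical helper):

```python
import re
from typing import Optional

# Matches e.g. "Epoch 2/300 | Train: 0.011234 | Val: 0.012345 | LR: 5.00e-04"
EPOCH_LINE = re.compile(
    r"Epoch (?P<epoch>\d+)/(?P<total>\d+) \| Train: (?P<train>[-+0-9.eE]+)"
    r" \| Val: (?P<val>[-+0-9.eE]+) \| LR: (?P<lr>[-+0-9.eE]+)"
)


def parse_epoch_line(line: str) -> Optional[dict]:
    """Parse one per-epoch progress line; return None for any other output."""
    match = EPOCH_LINE.search(line)
    if match is None:
        return None
    return {
        "epoch": int(match["epoch"]),
        "epochs": int(match["total"]),
        "train": float(match["train"]),
        "val": float(match["val"]),
        "lr": float(match["lr"]),
    }
```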
Track experiments, visualize metrics, and record checkpoint metadata with Weights & Biases (W&B).
conda activate vae
pip install wandb
wandb login # One-time setup (from login node with internet)
# Enable W&B logging
python scripts/train.py training.wandb.enabled=true
# Customize W&B settings
python scripts/train.py \
training.wandb.enabled=true \
training.wandb.project=my-vae-experiments \
 training.wandb.offline=false
W&B settings in configs/training/default.yaml:
| Parameter | Default | Description |
|---|---|---|
| wandb.enabled | false | Enable/disable W&B logging |
| wandb.project | beam-vae | W&B project name |
| wandb.entity | null | W&B team/username (null = default) |
| wandb.offline | true | Offline mode (sync later) |
| wandb.tags | [] | Optional tags for run organization |
| wandb.notes | null | Optional run description |
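These settings correspond fairly directly to keyword arguments of wandb.init (project, entity, name, dir, mode, tags, notes, config). The mapping below is an assumed sketch, not the contents of src/utils/wandb_init.py, and wandb_init_kwargs is a hypothetical name; setting dir to the run directory is what would place offline logs under runs/<run_name>/wandb/.

```python
from typing import Any


def wandb_init_kwargs(wandb_cfg: dict, run_name: str, run_dir: str, full_config: dict) -> dict[str, Any]:
    """Map training.wandb.* settings onto wandb.init() keyword arguments.

    The caller is expected to check wandb_cfg["enabled"] before using this.
    """
    return {
        "project": wandb_cfg.get("project", "beam-vae"),
        "entity": wandb_cfg.get("entity"),  # None -> W&B default entity
        "name": run_name,
        "dir": run_dir,  # W&B writes into <run_dir>/wandb/
        "mode": "offline" if wandb_cfg.get("offline", True) else "online",
        "tags": list(wandb_cfg.get("tags") or []),
        "notes": wandb_cfg.get("notes"),
        "config": full_config,  # full resolved config for reproducibility
    }


# Usage (requires the wandb package):
#   import wandb
#   run = wandb.init(**wandb_init_kwargs(cfg["training"]["wandb"], run_name, run_dir, cfg))
```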
Use offline mode (the default) so that training does not need internet access:
SLURM jobs: W&B logs are automatically synced at the end of each job.
Manual runs: Sync logs afterwards from a login node:
# Sync all offline runs
./slurm/sync_wandb.sh
# Or sync specific run
wandb sync runs/<run_name>/wandb/offline-run-*
View on W&B dashboard: Visit https://wandb.ai to see your synced runs.
W&B tracks per-epoch metrics:
- train/total_loss - Total training loss
- train/recon_loss - Training reconstruction loss
- train/kl_loss - Training KL divergence
- val/total_loss - Validation total loss
- val/recon_loss - Validation reconstruction loss
- val/kl_loss - Validation KL divergence
- learning_rate - Current learning rate
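The W&B keys line up one-to-one with the history CSV columns. A sketch of the renaming step, assuming that correspondence (CSV_TO_WANDB and to_wandb_metrics are hypothetical names); in a training loop the result would be passed to wandb.log(metrics, step=epoch):

```python
# Assumed correspondence between history CSV columns and W&B metric keys.
CSV_TO_WANDB = {
    "train_total": "train/total_loss",
    "train_recon": "train/recon_loss",
    "train_kl": "train/kl_loss",
    "val_total": "val/total_loss",
    "val_recon": "val/recon_loss",
    "val_kl": "val/kl_loss",
}


def to_wandb_metrics(history_row: dict, lr: float) -> dict[str, float]:
    """Rename one epoch's losses to W&B keys and add the learning rate."""
    metrics = {CSV_TO_WANDB[k]: float(v) for k, v in history_row.items() if k in CSV_TO_WANDB}
    metrics["learning_rate"] = lr
    return metrics
```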
W&B logs checkpoint metadata (file paths and metrics) without uploading the actual checkpoint files:
- Best model: Path, epoch, and validation loss tracked in run summary
- Periodic checkpoints: Path logged at each checkpoint interval
- Checkpoint files remain local; the logged paths show where to find them on disk
This approach keeps W&B runs lightweight while maintaining full checkpoint traceability.
Independent of W&B, the trainer saves:
- Best model: <run_name>_best.pth (updated when validation loss improves)
- Periodic: <run_name>_epoch{N}.pth (every checkpoint_freq epochs, default 50)
- Final model: <run_name>.pth (end of training)
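Which files a complete run leaves behind follows from epochs and checkpoint_freq. This sketch assumes 1-based epoch numbering and a periodic save whenever the epoch is a multiple of checkpoint_freq (expected_checkpoints is a hypothetical helper), which would give 8 .pth files for the defaults of 300 epochs and a frequency of 50:

```python
def expected_checkpoints(run_name: str, epochs: int, checkpoint_freq: int = 50) -> list[str]:
    """List the .pth files a full run should leave in runs/<run_name>/."""
    periodic = [
        f"{run_name}_epoch{epoch}.pth"
        for epoch in range(checkpoint_freq, epochs + 1, checkpoint_freq)
    ]
    return [f"{run_name}_best.pth", *periodic, f"{run_name}.pth"]
```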
Adjust checkpoint frequency:
python scripts/train.py training.checkpoint_freq=25 # Save every 25 epochs
Use the SLURM scan scripts for parallel sweeps:
# Edit slurm/submit_1d_scan.sh to configure your sweep, then:
sbatch slurm/submit_1d_scan.sh
# W&B logs are automatically synced at the end of the job
Then compare all runs on the W&B dashboard with interactive plots and parallel coordinates.
Submit SLURM jobs from the repository root. SLURM logs are written to logs/; create this directory before the first submission (e.g. mkdir -p logs).
sbatch slurm/submit_single.sh
Edit RUN_PREFIX and OVERRIDES in the script to configure the run.
Run a sweep over a single parameter using 4 GPUs (1 node) in parallel:
sbatch slurm/submit_1d_scan.sh
Edit PARAM_NAME and PARAM_VALUES in the script.
Run a grid search over two parameters:
sbatch slurm/submit_2d_grid.sh
Edit PARAM1_* and PARAM2_* variables in the script.
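The 2D grid amounts to one training run per pair in the Cartesian product of the two value lists. The Python sketch below only illustrates that enumeration; how submit_2d_grid.sh actually assigns runs to GPUs and names them is not described here, so grid_commands and the run_name pattern are hypothetical.

```python
import itertools
import shlex


def grid_commands(param1: str, values1: list, param2: str, values2: list,
                  run_prefix: str = "grid") -> list[str]:
    """Build one train.py command line per (value1, value2) pair."""
    commands = []
    for v1, v2 in itertools.product(values1, values2):
        # Hypothetical naming: prefix plus the last component of each dotted key.
        run_name = f"{run_prefix}_{param1.split('.')[-1]}{v1}_{param2.split('.')[-1]}{v2}"
        args = ["python", "scripts/train.py", f"{param1}={v1}", f"{param2}={v2}",
                f"run_name={run_name}"]
        commands.append(shlex.join(args))  # quotes any shell-unsafe characters
    return commands
```

A grid of three latent dimensions by two beta values, for example, yields 6 commands.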
#!/bin/bash
#SBATCH --job-name=vae_train
#SBATCH --time=03:00:00
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --constraint=gpu
#SBATCH --qos=regular
#SBATCH --account=m5089
cd $PSCRATCH/vae
ml load conda
conda activate vae
python scripts/train.py model.latent_dim=64 training.epochs=500
python scripts/visualize_loss.py runs/<run_name>/<run_name>_history.csv --save
Creates a PNG with training/validation curves for total loss, reconstruction loss, and KL divergence.
python scripts/visualize_recon.py \
runs/<run_name>/<run_name>.pth \
/pscratch/sd/n/ndwang/frequency_maps/frequency_maps_minmax.npy \
--sample-index 0 \
 --channels 0 1 2 3 4
Creates a comparison plot showing original, reconstruction, and error for selected channels.
vae/
├── configs/ # YAML configuration files
│ ├── default.yaml
│ ├── model/
│ ├── training/
│ └── data/
├── scripts/ # Entry point scripts
│ ├── train.py # Main training script
│ ├── visualize_loss.py # Loss curve plotting
│ └── visualize_recon.py # Reconstruction visualization
├── slurm/ # NERSC job scripts
│ ├── submit_single.sh # Single training run
│ ├── submit_1d_scan.sh # 1D parameter sweep
│ ├── submit_2d_grid.sh # 2D grid search
│ └── sync_wandb.sh # Sync W&B for manual runs
├── logs/ # SLURM output logs (not in git)
├── src/ # Source code
│ ├── models/ # VAE architectures
│ │ ├── vae2d.py
│ │ └── residual_vae2d.py
│ ├── data/ # Dataset classes
│ │ └── dataset.py
│ ├── training/ # Training utilities
│ │ ├── trainer.py
│ │ └── losses.py
│ └── utils/ # Utilities
│ ├── config.py # Config loading with CLI overrides
│ ├── validation.py # Pydantic config schema validation
│ ├── activations.py
│ ├── logging.py # W&B callback classes
│ └── wandb_init.py # W&B initialization
├── data/ # Dataset files (not in git)
├── runs/ # Training outputs (not in git)
├── CLAUDE.md
├── README.md
├── pyproject.toml
└── requirements.txt
Standard convolutional VAE with:
- 5 encoder blocks with strided convolution downsampling
- Bottleneck with FC layers for μ and log(σ²)
- 5 decoder blocks with bilinear upsampling
- Sigmoid output activation
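The shape bookkeeping for this layout can be worked through without PyTorch. Assuming each encoder block halves the spatial size (stride 2) and that μ and log(σ²) each come from a single linear layer, the defaults from the model table (input_size 64, hidden_channels [32, 64, 128, 256, 512], latent_dim 64) give a 512 × 2 × 2 feature map, i.e. 2048 features into each head. The helper names are hypothetical, and reparameterize shows the standard trick z = μ + σ·ε rather than the repository's code.

```python
import math


def encoder_shapes(input_size=64, hidden_channels=(32, 64, 128, 256, 512)):
    """Return (channels, height, width) after each encoder block.

    Assumes every block halves H and W (stride-2 convolution with padding).
    """
    shapes = []
    size = input_size
    for channels in hidden_channels:
        if size % 2:
            raise ValueError(f"spatial size {size} cannot be halved evenly")
        size //= 2
        shapes.append((channels, size, size))
    return shapes


def bottleneck_sizes(input_size=64, hidden_channels=(32, 64, 128, 256, 512), latent_dim=64):
    """Flattened feature count and parameter count of one linear head (mu or logvar)."""
    channels, height, width = encoder_shapes(input_size, hidden_channels)[-1]
    flat = channels * height * width
    head_params = flat * latent_dim + latent_dim  # weight matrix + bias
    return {"flat_features": flat, "head_params": head_params}


def reparameterize(mu, logvar, eps):
    """z = mu + exp(0.5 * logvar) * eps, with eps drawn from N(0, 1) by the caller."""
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, logvar, eps)]
```

The decoder would then mirror this path, with five bilinear upsampling steps taking 2 × 2 back to 64 × 64.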
Enhanced VAE with residual connections:
- Residual blocks before each down/upsample operation
- Skip connections for better gradient flow
- Roughly twice as many parameters as the standard VAE