Let's UNITE! Single-stage training, Unified model for Tokenization & Generation
Paper | Project Page
This is the code repository of the paper:
End-to-End Training for Unified Tokenization and Latent Denoising (arXiv 2026)
Shivam Duggal*, Xingjian Bai*, Zongze Wu, Richard Zhang, Eli Shechtman, Antonio Torralba, Phillip Isola, William T. Freeman
MIT, Adobe
Abstract
Approach Overview
Setup
Datasets
Training
Pretrained Checkpoints
Evaluation
Citation
# Clone and install dependencies
mamba create -n unite python=3.10 -y
mamba activate unite
pip install uv
# Install PyTorch 2.10.0 with CUDA 12.8 or higher (or your own CUDA version)
uv pip install -r requirements.txt

Required packages: torch>=2.1, torchvision, einops, torchdiffeq, lpips, pyyaml, clean-fid.
The main requirement is a torch version that supports the Muon optimizer.
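As a quick sanity check, the installed version string can be compared against the torch>=2.1 floor from requirements.txt. The helper below is a hypothetical sketch, not part of the repo; the exact minimum version needed for Muon support may differ:

```python
# Hypothetical version check; the torch>=2.1 floor comes from requirements.txt.
def version_tuple(v: str) -> tuple:
    """Parse a version string like '2.10.0+cu128' into (2, 10, 0)."""
    core = v.split("+")[0]  # drop local build tags such as '+cu128'
    return tuple(int(p) for p in core.split(".")[:3])

def supports_muon(torch_version: str, minimum: str = "2.1.0") -> bool:
    """True if the installed torch version meets the assumed minimum."""
    return version_tuple(torch_version) >= version_tuple(minimum)

print(supports_muon("2.10.0+cu128"))  # True: new enough build
print(supports_muon("2.0.1"))         # False: predates the floor
```

In practice you would `import torch` and pass `torch.__version__`.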
For FID evaluation on ImageNet 256×256:
wget https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth
wget https://raw.githubusercontent.com/LTH14/JiT/main/fid_stats/jit_in256_stats.npz
export INCEPTION_WEIGHTS=/path/to/inception_weights
export IN256_FID_STATS=/path/to/in256_fid_stats
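For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features. A minimal numpy sketch of that distance (assuming the stats .npz stores a mean/covariance pair, as clean-fid-style files typically do):

```python
import numpy as np

def sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues
    return (v * np.sqrt(w)) @ v.T

def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    diff = mu1 - mu2
    s1_half = sqrtm_psd(sigma1)
    # Tr((S1 S2)^{1/2}) equals Tr((S1^{1/2} S2 S1^{1/2})^{1/2}), and the
    # inner product is symmetric PSD, so the eigendecomposition route is safe.
    covmean = sqrtm_psd(s1_half @ sigma2 @ s1_half)
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

Identical statistics give a distance of 0; the reference stats downloaded above supply one (mu, sigma) side of the comparison.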
Training uses the ImageNet-1K dataset (ILSVRC2012). Set the DATA_PATH environment variable to the training split:
export DATA_PATH=/path/to/imagenet/train

FID evaluation uses precomputed ImageNet statistics from clean-fid.
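torchvision-style loaders expect the train split laid out as one sub-folder per class. A quick sanity check with a hypothetical helper (ImageNet-1K should yield 1000 class folders):

```python
import os

def count_classes(data_path: str) -> int:
    """Count class sub-folders under the ImageNet train split
    (torchvision ImageFolder layout: one directory per class)."""
    return sum(1 for entry in os.scandir(data_path) if entry.is_dir())
```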
For our results on the molecules dataset, please refer to the paper.
Single-node (8 GPU) training on ImageNet 256×256:
# Using the launch script (auto-detects GPUs and data path)
bash run_scripts/train.sh
# Or directly with torchrun
torchrun --nproc_per_node=8 main_train.py \
--config configs/imagenet.yaml \
--data-path $DATA_PATH \
--experiment-name lets-unite

To resume from a checkpoint, additionally pass the checkpoint path via the --ckpt argument.
Key config fields (see configs/imagenet.yaml):
| Field | Default | Description |
|---|---|---|
| `total_batch_size` | 1024 | Effective batch size across all GPUs and accumulation steps |
| `grad_accum_steps` | 2 | Gradient accumulation steps (batch/GPU = total / world_size / accum) |
| `flow_steps_per_recon` | 3 | Flow matching steps per reconstruction forward pass |
| `flow_mini_batch` | 4 | Chunk size for flow steps (controls peak memory) |
| `torch_compile_decoder` | true | torch.compile the decoder for faster training |
Evaluation during training: FID is evaluated automatically at the end of each fid_epoch interval (default: every 40 epochs). The evaluation uses adaptive CFG scale search: starting from the current best scale, it tests neighboring scales and updates the best. At fid_sweep_epoch intervals (default: 120 epochs), a more comprehensive sweep is performed across multiple CFG intervals and normalization orders.
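The neighbor-probing loop described above can be sketched as follows; `evaluate_fid`, the step size, and the stopping rule are illustrative stand-ins, not the repo's actual implementation:

```python
# Sketch of an adaptive CFG scale search: from the current best scale,
# probe both neighbors, keep whichever improves, stop when neither does.
def adaptive_cfg_search(evaluate_fid, start_scale, step=0.1, max_rounds=20):
    best_scale, best_fid = start_scale, evaluate_fid(start_scale)
    for _ in range(max_rounds):
        improved = False
        for cand in (best_scale - step, best_scale + step):
            fid = evaluate_fid(cand)
            if fid < best_fid:
                best_scale, best_fid, improved = cand, fid, True
        if not improved:
            break
    return best_scale, best_fid

# Toy FID landscape with its minimum at scale 1.5.
scale, fid = adaptive_cfg_search(lambda s: (s - 1.5) ** 2, start_scale=1.0)
```

In training, each `evaluate_fid` call would be a full FID evaluation, so the search stays cheap by only probing scales adjacent to the current best.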
To reproduce the paper results for UNITE-B (with 3 flow mini-batches per reconstruction step) on a single node on ImageNet-1K 256×256, use the config configs/imagenet.yaml.
| Setting | Value |
|---|---|
| Architecture | Base encoder (130.6M params) + Base decoder (86.2M params) |
| Total params | 217.6M (all trainable) |
| Hardware | 1 node × 8 NVIDIA H200 (140 GB each) |
| Effective batch size | 1024 (64 per GPU × 2 gradient accumulation steps) |
| Precision | BF16 mixed precision |
| Optimizer | Muon |
| Training speed | ~26 min/epoch |
UNITE has an adversarial nature due to the joint optimization of the reconstruction and denoising losses; see Sec. 3.3 of the paper for details. The training curves should resemble the following graph:
FID and Inception Score (IS) are computed with adaptive CFG scale search: starting from the current best scale, neighboring scales are evaluated and the best is kept. All evaluations use the EMA model.
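The EMA weights used for evaluation follow the standard exponential-moving-average update. A minimal sketch with an assumed decay value (the repo's actual decay and parameter handling may differ):

```python
def ema_update(ema_params: dict, model_params: dict, decay: float = 0.999) -> dict:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * model."""
    for k in ema_params:
        ema_params[k] = decay * ema_params[k] + (1.0 - decay) * model_params[k]
    return ema_params

ema = {"w": 0.0}
for _ in range(3):
    ema_update(ema, {"w": 1.0}, decay=0.5)
print(ema["w"])  # 0.875: three half-steps toward the online weight 1.0
```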
| Model | Encoder | Decoder | Total Params | Epochs | FID-50K | Checkpoint |
|---|---|---|---|---|---|---|
| UNITE-B | Base | Base | 217.6M | 240 | 2.12 | Download Link |
| UNITE-L | Large | Base | 589.0M | 120 | 1.73 | Download Link |
Each checkpoint includes the model weights (encoder + decoder), EMA state, optimizer state, and scheduler state for seamless resumption. Both models were trained with 14 flow iterations per reconstruction step, i.e. flow_steps_per_recon = 14.
We will release the inference-time evaluation for generation and reconstruction FID and IS soon. Until then, enjoy the following results from the XL model.
@article{duggal2026unite,
title={End-to-End Training for Unified Tokenization and Latent Denoising},
author={Shivam Duggal and Xingjian Bai and Zongze Wu and Richard Zhang and Eli Shechtman and Antonio Torralba and Phillip Isola and William T. Freeman},
journal={arXiv preprint arXiv:2603.22283},
year={2026}
}