This guide explains how to use the flow matching components that have been added to this codebase.
Flow matching is a generative modeling approach that learns to transform noise into data through velocity field prediction. It's conceptually simpler than diffusion models while achieving comparable quality.
Train a flow matching model on CIFAR-10:
```bash
python train_flow_matching.py
```

This will:
- Load the CIFAR-10 dataset automatically
- Train a DiT2D model with flow matching
- Save checkpoints every 5000 steps to `checkpoints/`
- Generate sample images every 1000 steps to `samples/`
- Use mixed precision (bfloat16) for efficient training
- Apply an Exponential Moving Average (EMA) of the weights for stable sampling
After training, generate images:
```bash
# Generate 16 random samples
python generate_flow_matching.py --checkpoint checkpoints/flow_matching_final.pt --num_samples 16

# Generate samples for a specific class
python generate_flow_matching.py --checkpoint checkpoints/flow_matching_final.pt --num_samples 16 --class_label 5

# Generate a grid showing all classes
python generate_flow_matching.py --checkpoint checkpoints/flow_matching_final.pt --class_grid --samples_per_class 8

# Adjust sampling quality/speed
python generate_flow_matching.py --checkpoint checkpoints/flow_matching_final.pt --num_samples 16 --num_steps 100 --cfg_scale 4.0
```

The model is based on the DiT architecture adapted for 2D images:
- Patch Embedding: 2x2 patches convert images to tokens
- Positional Encoding: 2D sinusoidal position embeddings
- Transformer Blocks: Standard attention + MLP with AdaLN conditioning
- Time Conditioning: Sinusoidal timestep embeddings fed through MLP
- Class Conditioning: Optional label embeddings for conditional generation
- Output: Predicts velocity field v(x_t, t)
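As a sketch of the patch-embedding step, here is how 2x2 patchification turns an image batch into a token sequence. This is an illustrative NumPy stand-in, not the codebase's actual layer; the real patch embedding would follow this with a learned linear projection.

```python
import numpy as np

def patchify(images, patch=2):
    """Split (B, C, H, W) images into (B, num_tokens, patch*patch*C) tokens.

    Illustrative NumPy version of the input to a DiT-style patch embedding.
    """
    b, c, h, w = images.shape
    assert h % patch == 0 and w % patch == 0
    # (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, p, p, C)
    x = images.reshape(b, c, h // patch, patch, w // patch, patch)
    x = x.transpose(0, 2, 4, 3, 5, 1)
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)

# A 32x32 RGB image with 2x2 patches yields 16*16 = 256 tokens of size 12.
tokens = patchify(np.zeros((1, 3, 32, 32)))
```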
Edit `configs/flow_matching_config.py` to customize:

```python
@dataclass
class FlowMatchingConfig:
    # Model architecture
    hidden_size: int = 768        # Model width
    depth: int = 12               # Number of layers
    num_heads: int = 12           # Attention heads

    # Training
    batch_size: int = 128
    lr: float = 1e-4
    train_steps: int = 100000

    # Sampling
    num_sampling_steps: int = 50  # Quality/speed tradeoff
    cfg_scale: float = 3.0        # Guidance strength
```

Each training step does the following:
- Sample timestep: t ~ Uniform(0, 1)
- Create interpolation: x_t = (1-t) * noise + t * data
- Compute target: v_target = data - noise
- Predict velocity: v_pred = model(x_t, t, class_label)
- Loss: MSE(v_pred, v_target)
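The training steps above can be sketched in a few lines. This is a NumPy illustration with `noise` and `t` passed in explicitly for clarity; the real training loop samples them per batch element.

```python
import numpy as np

def flow_matching_example(data, noise, t):
    """Build a flow matching training pair for one timestep t in [0, 1].

    x_t linearly interpolates noise -> data; the regression target is the
    constant velocity of that straight-line path.
    """
    x_t = (1.0 - t) * noise + t * data   # interpolant
    v_target = data - noise              # velocity of the linear path
    return x_t, v_target

# The model is then trained with MSE(model(x_t, t, class_label), v_target).
```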
Sampling reverses the process:
- Start from noise: x_0 ~ N(0, I)
- Euler integration: for t = 0 to 1 with step dt:
  - Predict v_t = model(x_t, t, class_label)
  - Update: x_{t+dt} = x_t + v_t * dt
- Result: x_1 is the generated image
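A minimal Euler sampler following these steps might look like the sketch below, where `model` is a stand-in callable for the trained network:

```python
import numpy as np

def euler_sample(model, x0, num_steps):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with Euler steps."""
    x = x0.copy()
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + model(x, t) * dt  # x_{t+dt} = x_t + v_t * dt
    return x
```

In practice `x0` is drawn from N(0, I) and the model call also takes the class label (and applies classifier-free guidance).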
During training, class labels are randomly dropped 10% of the time so the model learns both conditional and unconditional velocities. During sampling, the two predictions are combined:

```
v_guided = v_uncond + cfg_scale * (v_cond - v_uncond)
```
Higher cfg_scale (e.g., 4.0) gives more class-specific but potentially less diverse samples.
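The guidance combination is a one-liner; as a sketch:

```python
import numpy as np

def cfg_velocity(v_cond, v_uncond, cfg_scale):
    """Classifier-free guidance: extrapolate past v_uncond toward v_cond.

    cfg_scale = 1.0 recovers the purely conditional prediction;
    cfg_scale = 0.0 recovers the unconditional one.
    """
    return v_uncond + cfg_scale * (v_cond - v_uncond)
```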
The training script maintains an EMA of model parameters with decay 0.9999. The EMA model is used for sampling, which typically produces higher quality results.
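The EMA update itself is simple. Here is a sketch over a plain parameter dict; the actual trainer operates on the model's state dict.

```python
def ema_update(ema_params, params, decay=0.9999):
    """Per-parameter update: ema <- decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params
```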
Training uses PyTorch AMP with bfloat16 for faster training and lower memory usage:
- ~2x speedup on modern GPUs
- Numerically stable without gradient scaling, since bfloat16 keeps float32's exponent range
The following components from the original codebase were reused:
- VQ-VAE (`models/vqvae.py`): Can be used for latent flow matching
- DiT Architecture (`models/dit.py`, `models/layers.py`): Adapted to 2D
- Training Infrastructure (`training/trainer.py`): Mixed precision, optimizers
- Data Pipeline (`data/vqvae_dataset.py`): CIFAR-10 loading
- Optimizer (`optimizers/muon.py`): Can replace AdamW if desired
To train in VQ-VAE latent space instead of pixel space:

- Train a VQ-VAE first:

  ```bash
  python train_vqvae.py
  ```

- Load the VQ-VAE encoder in the dataset:

  ```python
  from models.vqvae import VQVAE

  vqvae = VQVAE.load_from_checkpoint("checkpoints/vqvae.pt")
  encoder = vqvae.encoder

  dataloader = get_flow_matching_dataloader(
      vqvae_encoder=encoder,
      ...
  )
  ```

Replace CIFAR-10 with your own dataset:
```python
class CustomFlowMatchingDataset(Dataset):
    def __getitem__(self, idx):
        image = load_your_image(idx)
        label = load_your_label(idx)
        image = preprocess(image)  # Normalize to [-1, 1]
        return image, label
```

For larger models:
- Increase `hidden_size`, `depth`, `num_heads` in the config
- Use gradient accumulation by reducing `micro_batch_size`
- Consider using the Muon optimizer instead of AdamW
- Enable gradient checkpointing to save memory
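Gradient accumulation simply averages micro-batch gradients before taking one optimizer step, so the effective batch size can exceed what fits in memory. A framework-agnostic sketch (the names here are illustrative, not the trainer's API):

```python
import numpy as np

def accumulated_step(param, micro_batches, grad_fn, lr):
    """One SGD step from gradients averaged over several micro-batches.

    Equivalent to a single step on the concatenated full batch.
    """
    grads = [grad_fn(param, mb) for mb in micro_batches]
    return param - lr * np.mean(grads, axis=0)
```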
```
image-gen/
├── configs/
│   └── flow_matching_config.py   # Configuration
├── models/
│   ├── dit_2d.py                 # 2D DiT model
│   └── layers.py                 # Building blocks
├── data/
│   └── flow_matching_dataset.py  # Dataset loaders
├── train_flow_matching.py        # Training script
├── generate_flow_matching.py     # Inference script
└── FLOW_MATCHING_GUIDE.md        # This file
```
- Training Duration: 50k-100k steps usually sufficient for CIFAR-10
- Sampling Steps: 50 steps is a good balance, 100+ for best quality
- CFG Scale: Try 2.0-4.0, higher = more adherence to class labels
- Batch Size: Larger is better (up to memory limits)
- Learning Rate: 1e-4 works well, use warmup for stability
- EMA: Always sample with the EMA weights; they significantly improve quality
Q: Loss is not decreasing
- Check learning rate (try 1e-4 to 5e-4)
- Ensure images are normalized to [-1, 1]
- Verify dataset is loading correctly
Q: Generated images are blurry
- Increase number of sampling steps
- Use EMA model for sampling
- Train for more steps
Q: Out of memory
- Reduce `micro_batch_size`
- Reduce model size (`hidden_size`, `depth`)
- Use gradient checkpointing
Q: Training is slow
- Ensure CUDA is available
- Use a smaller `num_workers` if the CPU is the bottleneck
- Consider reducing image resolution
- Flow Matching: Lipman et al., 2023
- DiT: Peebles & Xie, 2023
- Classifier-Free Guidance: Ho & Salimans, 2022
Same as the main project.