A minimal, developer-friendly implementation of a Conditional Variational Autoencoder (CVAE) for generating handwritten digits on MNIST. The project includes a Typer-based CLI, optional FastAPI service, MLflow experiment tracking, and reproducible configuration via params.yaml.
- A Variational Autoencoder (VAE) learns a compact latent representation and reconstructs inputs.
- A Conditional VAE (CVAE) lets you control what gets generated by conditioning on a label (e.g., generate a specific digit like "6").
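
To make the conditioning idea concrete, here is a minimal sketch of one common way a CVAE decoder receives its label: a one-hot vector is concatenated with the latent sample. It is an illustration in plain Python, not the code in this repository. The helper names (`one_hot`, `reparameterize`, `decoder_input`) are hypothetical, and concatenation is assumed rather than confirmed as this project's conditioning method.

```python
import math
import random


def one_hot(label: int, num_classes: int = 10) -> list[float]:
    """Encode a digit label as a one-hot vector (hypothetical helper)."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes - 1}]")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def reparameterize(mu: list[float], logvar: list[float], rng: random.Random) -> list[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), where sigma = exp(0.5 * logvar)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def decoder_input(z: list[float], label: int) -> list[float]:
    """Concatenate the latent sample with the one-hot label (assumed conditioning scheme)."""
    return list(z) + one_hot(label)


rng = random.Random(0)
latent_dim = 20  # matches the example value of --latent_dim

# At generation time z is drawn from the prior N(0, I), i.e. mu = 0 and logvar = 0.
z = reparameterize([0.0] * latent_dim, [0.0] * latent_dim, rng)
x_in = decoder_input(z, label=6)
print(len(x_in))  # latent_dim + 10 class slots
```

Because the label occupies its own slots in the decoder input, changing `label` while keeping `z` fixed asks the decoder for a different digit drawn in a similar style.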
- Python 3.10
- Poetry
- Git
- Optional: CUDA-capable GPU (training is CPU-friendly for MNIST)
git clone https://github.com/moreira-and/cvae-mnist.git
cd cvae-mnist
poetry install
# optional: spawn a shell
poetry shell
- Prepare the dataset
poetry run dataset
# saves tensors under data/processed/
- Train the model
poetry run train
# common flags: --latent_dim 20 --num_epochs 10 --batch_size 128 --lr 0.001
- Generate a digit
poetry run gen --digit 6
# writes reports/figures/cvae_digit6.png
- Plot comparison grid
poetry run plot
# writes reports/figures/cvae_comparison.png
- End-to-end experiment
poetry run experiment --do-data --do-train --do-plots
# use --no-do-data / --no-do-train / --no-do-plots to skip steps
- (Optional) Evaluate reconstructions
poetry run eval --max-batches 5 --sample-ssim 64
# writes reports/metrics.csv
- Notebook
  - Open `notebooks/quick_start.ipynb` and Run All.
- `data/processed/`: dataset tensors (train/test)
- `models/`: trained weights (`model.pth`, `decoder.pth`)
- `reports/figures/`: generated images (e.g., `cvae_digit6.png`)
- `reports/metrics.csv`: evaluation metrics (if `eval` is used)
- `mlruns/`: local MLflow tracking store
- `params.yaml`: batch size, epochs, latent dimension, learning rate, etc.
- `src/cvae/config.py`: paths, device selection, MLflow URI, parameter loading
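
As a rough illustration of how file-based parameters and command-line overrides can be combined, the sketch below merges a set of base values with optional overrides. It does not reproduce `src/cvae/config.py`. The `TrainParams` class and `merge_params` function are hypothetical, the field names are borrowed from the CLI flags, and the defaults are simply the example flag values from this README, not confirmed contents of `params.yaml`.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class TrainParams:
    # Assumed values: taken from the example flags in this README,
    # not read from the real params.yaml.
    latent_dim: int = 20
    num_epochs: int = 10
    batch_size: int = 128
    lr: float = 0.001


def merge_params(base: TrainParams, overrides: dict) -> TrainParams:
    """Return a copy of `base` with non-None overrides applied."""
    known = {f.name for f in fields(base)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"unknown parameter(s): {sorted(unknown)}")
    # A value of None means "flag not given", so the base value is kept.
    applied = {k: v for k, v in overrides.items() if v is not None}
    return replace(base, **applied)


file_params = TrainParams()  # stand-in for values parsed from params.yaml
params = merge_params(file_params, {"num_epochs": 20, "lr": None})
print(params)
```

Keeping the merged result immutable (`frozen=True`) makes it easy to log exactly the parameters a run used.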
poetry run mlflow ui --backend-store-uri mlruns --host 0.0.0.0 --port 5000
Open http://localhost:5000 and browse runs, parameters, metrics, and artifacts.
- `dataset`: downloads and preprocesses MNIST into `data/processed/`.
- `train`: trains the CVAE.
  - Flags: `--latent_dim`, `--num_epochs`, `--batch_size`, `--test_batch_size`, `--lr`, `--momentum`, `--seed`.
- `gen`: generates a single digit image.
  - Flags: `--digit` / `-d` (0–9). Output saved to `reports/figures/cvae_digit{d}.png`.
- `plot`: creates a comparison grid at `reports/figures/cvae_comparison.png`.
- `experiment`: runs `dataset` → `train` → `plot`.
  - Flags: `--do-data/--no-do-data`, `--do-train/--no-do-train`, `--do-plots/--no-do-plots`.
- `eval`: evaluates reconstruction quality on the test split.
  - Flags: `--max-batches`, `--sample-ssim`, `--write-csv/--no-write-csv`.
You can also invoke the package with python -m cvae to show Typer help.
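
The `eval` command reports reconstruction metrics, and the project layout lists MSE and optional PSNR among them. As a reference for reading those numbers, here is a small sketch of the standard MSE and PSNR definitions on flattened pixel values in [0, 1]. The function names are hypothetical and this is not the code in `eval.py`.

```python
import math


def mse(original: list[float], reconstruction: list[float]) -> float:
    """Mean squared error between two equally sized pixel sequences."""
    if not original or len(original) != len(reconstruction):
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((a - b) ** 2 for a, b in zip(original, reconstruction)) / len(original)


def psnr(original: list[float], reconstruction: list[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(max_value^2 / MSE)."""
    err = mse(original, reconstruction)
    if err == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / err)


x = [0.0, 0.5, 1.0, 1.0]      # toy "image"
x_hat = [0.1, 0.5, 0.9, 1.0]  # toy reconstruction
print(round(mse(x, x_hat), 4))
print(round(psnr(x, x_hat), 2))
```

Higher PSNR means a closer reconstruction. Lower MSE means the same thing, so the two metrics move in opposite directions.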
A lightweight service is provided in src/cvae/api.py.
- Run locally:
poetry run uvicorn cvae.api:app --host 0.0.0.0 --port 8000 --reload
- Useful endpoints:
  - `GET /`: HTML with examples
  - `GET /health`: service status
  - `POST /data`: prepare dataset
  - `POST /train`: train model
  - `POST /plots`: generate comparison plot
  - `POST /gen`: generate image for a given digit, e.g. `{ "digit": 7 }`
  - `POST /experiment`: run full pipeline with toggles
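
For calling the service from a script, the sketch below builds a `POST /gen` request with the standard library. It assumes the service is running locally on port 8000 as in the `uvicorn` command above and accepts the JSON body `{ "digit": 7 }`. The response format is not documented here, so `send` returns the raw bytes. The helper names `build_gen_request` and `send` are hypothetical.

```python
import json
import urllib.request


def build_gen_request(digit: int, base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build (but do not send) a POST /gen request for the given digit."""
    if not 0 <= digit <= 9:
        raise ValueError("digit must be between 0 and 9")
    body = json.dumps({"digit": digit}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/gen",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 30.0) -> bytes:
    """Send the request and return the raw response body (requires a running service)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read()


request = build_gen_request(7)
print(request.get_method(), request.full_url)
```

Separating request construction from sending keeps the payload easy to inspect and test without a live server.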
cvae-mnist/
- src/
  - cvae/
    - __main__.py # enables `python -m cvae`
    - cli.py # console entry points
    - config.py # paths, MLflow, device, params
    - api.py # FastAPI service
    - service/
      - dataset.py # dataset preparation
      - train.py # training loop
      - gen.py # inference / image generation
      - plots.py # visualization utilities
      - eval.py # evaluation metrics (BCE/MSE + optional SSIM/PSNR)
      - utils.py # data loader, loss, save_model
    - models/
      - nn_cvae.py
      - conditional_encoder.py
      - conditional_decoder.py
- notebooks/
  - quick_start.ipynb
- models/
  - model.pth
  - decoder.pth
- reports/
  - figures/
    - cvae_comparison.png
    - cvae_digit5.png
    - cvae_digit6.png
    - cvae_digit7.png
  - performance_analysis.md
- params.yaml
- Makefile
- pyproject.toml
- setup.cfg
- LICENSE
- README.md
After training, the CVAE generates readable digits conditioned on labels. A sample comparison grid is written to reports/figures/cvae_comparison.png.
If results are blurry at first, try more epochs (e.g., --num_epochs 20) or a larger latent_dim.
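
To see why training length and latent size affect sharpness, it can help to look at the usual VAE objective: a reconstruction term plus a KL term that pulls the encoder's distribution toward a standard normal. The sketch below shows that standard formulation in plain Python. It is not the loss in `utils.py`, the function names are hypothetical, and the `beta` weight is an added illustration parameter, not a flag this project exposes.

```python
import math


def bce(x: list[float], x_hat: list[float], eps: float = 1e-7) -> float:
    """Binary cross-entropy summed over pixels; predictions are clamped for stability."""
    total = 0.0
    for target, pred in zip(x, x_hat):
        p = min(max(pred, eps), 1.0 - eps)
        total -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
    return total


def kl_to_standard_normal(mu: list[float], logvar: list[float]) -> float:
    """KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian, with logvar = log(sigma^2)."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def vae_loss(x, x_hat, mu, logvar, beta: float = 1.0) -> float:
    """Reconstruction term plus beta-weighted KL term (beta is hypothetical here)."""
    return bce(x, x_hat) + beta * kl_to_standard_normal(mu, logvar)


x = [0.0, 1.0, 1.0]
x_hat = [0.2, 0.7, 0.9]
mu, logvar = [0.3, -0.1], [0.0, -0.2]
print(round(vae_loss(x, x_hat, mu, logvar), 4))
```

Early in training the reconstruction term is still large, which tends to show up as blur. A larger latent space gives the decoder more capacity to encode stroke detail.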
- Formatting: `black`, `isort`
- Linting: `flake8`
- Tests: `pytest` (add tests under `tests/`)
Suggested commands:
poetry run black src
poetry run isort src
poetry run flake8 src
poetry run pytest
See CONTRIBUTING.md for guidelines on filing issues, proposing changes, and submitting PRs. All contributions and suggestions are welcome.
MIT — see LICENSE
