
C-VAE MNIST

A minimal, developer-friendly implementation of a Conditional Variational Autoencoder (CVAE) for generating handwritten digits on MNIST. The project includes a Typer-based CLI, optional FastAPI service, MLflow experiment tracking, and reproducible configuration via params.yaml.


What's a CVAE?

  • A Variational Autoencoder (VAE) learns a compact latent representation and reconstructs inputs.
  • A Conditional VAE (CVAE) lets you control what gets generated by conditioning on a label (e.g., generate a specific digit like "6").
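The two ideas above can be sketched in a few lines. This is an illustrative, framework-agnostic NumPy sketch, not the repository's actual code (the real models live in `src/cvae/service/models/`); the function names here are hypothetical:

```python
import numpy as np

def one_hot(digit: int, num_classes: int = 10) -> np.ndarray:
    """Encode the class label as a one-hot vector."""
    y = np.zeros(num_classes)
    y[digit] = 1.0
    return y

def condition(x: np.ndarray, digit: int) -> np.ndarray:
    """Concatenate the label onto the input -- the 'conditional' in CVAE."""
    return np.concatenate([x, one_hot(digit)])

def reparameterize(mu: np.ndarray, logvar: np.ndarray, seed: int = 0) -> np.ndarray:
    """Sample z = mu + sigma * eps (the reparameterization trick),
    so the sampling step stays differentiable w.r.t. mu and logvar."""
    eps = np.random.default_rng(seed).standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps
```

Because both the encoder and the decoder see the label, sampling a latent `z` and decoding it together with `one_hot(6)` steers generation toward a "6".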

Requirements

  • Python 3.10
  • Poetry
  • Git
  • Optional: CUDA-capable GPU (training is CPU-friendly for MNIST)

Installation

git clone https://github.com/moreira-and/cvae-mnist.git
cd cvae-mnist
poetry install
# optional: spawn a shell
poetry shell

Quick Start

  1. Prepare the dataset
poetry run dataset
# saves tensors under data/processed/
  2. Train the model
poetry run train
# common flags: --latent_dim 20 --num_epochs 10 --batch_size 128 --lr 0.001
  3. Generate a digit
poetry run gen --digit 6
# writes reports/figures/cvae_digit6.png
  4. Plot comparison grid
poetry run plot
# writes reports/figures/cvae_comparison.png
  5. End-to-end experiment
poetry run experiment --do-data --do-train --do-plots
# use --no-do-data / --no-do-train / --no-do-plots to skip steps
  6. (Optional) Evaluate reconstructions
poetry run eval --max-batches 5 --sample-ssim 64
# writes reports/metrics.csv
  7. Explore the notebook: open notebooks/quick_start.ipynb for an interactive walkthrough.

Where outputs are saved

  • data/processed/ — dataset tensors (train/test)
  • models/ — trained weights (model.pth, decoder.pth)
  • reports/figures/ — generated images (e.g., cvae_digit6.png)
  • reports/metrics.csv — evaluation metrics (if eval is used)
  • mlruns/ — local MLflow tracking store

Configuration

  • params.yaml — batch size, epochs, latent dimension, learning rate, etc.
  • src/cvae/config.py — paths, device selection, MLflow URI, parameter loading
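A plausible `params.yaml` might look like the fragment below. The key names are an assumption here, mirroring the documented `train` flags; the values are illustrative defaults, not the repository's actual settings:

```yaml
# params.yaml -- illustrative values; keys assumed to mirror the train flags
latent_dim: 20
num_epochs: 10
batch_size: 128
test_batch_size: 1000
lr: 0.001
momentum: 0.9
seed: 42
```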

MLflow UI (local)

poetry run mlflow ui --backend-store-uri mlruns --host 0.0.0.0 --port 5000

Open http://localhost:5000 and browse runs, parameters, metrics, and artifacts.


CLI Reference

  • dataset — downloads and preprocesses MNIST into data/processed/.
  • train — trains the CVAE.
    • Flags: --latent_dim, --num_epochs, --batch_size, --test_batch_size, --lr, --momentum, --seed.
  • gen — generates a single digit image.
    • Flags: --digit|-d (0–9). Output saved to reports/figures/cvae_digit{d}.png.
  • plot — creates a comparison grid at reports/figures/cvae_comparison.png.
  • experiment — runs dataset → train → plot.
    • Flags: --do-data/--no-do-data, --do-train/--no-do-train, --do-plots/--no-do-plots.
  • eval — evaluates reconstruction quality on the test split.
    • Flags: --max-batches, --sample-ssim, --write-csv/--no-write-csv.

You can also invoke the package with python -m cvae to show Typer help.
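The `eval` command reports reconstruction error plus optional SSIM/PSNR. As a reference for what PSNR measures, here is a minimal NumPy sketch (not the repository's implementation in `service/eval.py`):

```python
import numpy as np

def psnr(x: np.ndarray, x_hat: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB; higher means a closer reconstruction."""
    mse = float(np.mean((x - x_hat) ** 2))
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val**2 / mse)
```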


API (FastAPI)

A lightweight service is provided in src/cvae/api.py.

  • Run locally:
poetry run uvicorn cvae.api:app --host 0.0.0.0 --port 8000 --reload
  • Useful endpoints:
    • GET / — HTML with examples
    • GET /health — service status
    • POST /data — prepare dataset
    • POST /train — train model
    • POST /plots — generate comparison plot
    • POST /gen — generate image for a given digit { "digit": 7 }
    • POST /experiment — run full pipeline with toggles
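For example, the `POST /gen` endpoint can be called from the standard library alone. This is a sketch assuming the service is running on the default port shown above; only the request construction is shown here:

```python
import json
from urllib import request

API = "http://localhost:8000"

def gen_request(digit: int) -> request.Request:
    """Build a POST /gen request asking the service for one digit image."""
    payload = json.dumps({"digit": digit}).encode("utf-8")
    return request.Request(
        f"{API}/gen",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# To actually send it (with the service running):
#   with request.urlopen(gen_request(6)) as resp:
#       print(resp.status)
```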

Project Layout

cvae-mnist/
  - src/
    - cvae/
      - __main__.py        # enables `python -m cvae`
      - cli.py             # console entry points
      - config.py          # paths, MLflow, device, params
      - api.py             # FastAPI service
      - service/
        - dataset.py       # dataset preparation
        - train.py         # training loop
        - gen.py           # inference / image generation
        - plots.py         # visualization utilities
        - eval.py          # evaluation metrics (BCE/MSE + optional SSIM/PSNR)
        - utils.py         # data loader, loss, save_model
        - models/
          - nn_cvae.py
          - conditional_encoder.py
          - conditional_decoder.py
  - notebooks/
    - quick_start.ipynb
  - models/
    - model.pth
    - decoder.pth
  - reports/
    - figures/
      - cvae_comparison.png
      - cvae_digit5.png
      - cvae_digit6.png
      - cvae_digit7.png
    - performance_analysis.md
  - params.yaml
  - Makefile
  - pyproject.toml
  - setup.cfg
  - LICENSE
  - README.md
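The loss in `service/utils.py` presumably follows the standard CVAE objective: a reconstruction term (BCE) plus a KL regularizer pulling the latent posterior toward a standard normal. A framework-agnostic NumPy sketch of that objective (not the repository's actual code):

```python
import numpy as np

def kl_divergence(mu: np.ndarray, logvar: np.ndarray) -> float:
    """Closed-form KL(q(z|x,y) || N(0, I)) for a diagonal Gaussian."""
    return float(-0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar)))

def bce(x: np.ndarray, x_hat: np.ndarray, eps: float = 1e-7) -> float:
    """Pixel-wise binary cross-entropy, summed over the image."""
    x_hat = np.clip(x_hat, eps, 1 - eps)
    return float(-np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat)))

def cvae_loss(x, x_hat, mu, logvar):
    """Negated ELBO: reconstruction error plus KL regularizer."""
    return bce(x, x_hat) + kl_divergence(mu, logvar)
```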

Results

After training, the CVAE generates readable digits conditioned on labels. Sample comparison grid:

[Figure: reports/figures/cvae_comparison.png — grid comparing generated digits across classes]

If results are blurry at first, try more epochs (e.g., --num_epochs 20) or a larger latent_dim.


Development

  • Formatting: black, isort
  • Linting: flake8
  • Tests: pytest (add tests under tests/)

Suggested commands:

poetry run black src
poetry run isort src
poetry run flake8 src
poetry run pytest

Contributing

See CONTRIBUTING.md for guidelines on filing issues, proposing changes, and submitting PRs. All contributions and suggestions are welcome.


License

MIT — see LICENSE

About

The project aims to explore the potential of Conditional Variational Autoencoders (C-VAEs) for generating synthetic data from specific image classes in the MNIST dataset.
