This repository contains a unified pipeline for distilling a student VMamba model from a RETFound teacher, training classification heads, and evaluating on retinal fundus datasets (IDRiD, APTOS, MBRSET).
This README covers the CLI (`main.py`) and the project layout so you can run training and evaluation quickly.
- `main.py` - unified CLI entrypoint (see CLI section below)
- `config/` - `constants.py` (defaults and environment-driven overrides)
- `dataloader/` - `idrid.py`, `aptos.py` (Lightning DataModules)
- `models/` - `retfound.py`, `vmamba_backbone.py`, `dist.py`, `models_vit.py`
- `train/` - `distill.py`, `head.py`, `train_retfound.py`
- `eval/` - `eval_vmamba.py`, `eval_retfound.py`, `shared_eval.py`
- `optimizers/` - `optimizer.py` (helper for warmup + cosine schedule)
- `utils/` - utilities (`flops.py`, `pos_embed.py`, etc.)
- `imgs/`, `results/` - example outputs and CSVs
- `requirements.txt`, `simple_test.ipynb`
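For orientation, the warmup + cosine schedule that `optimizers/optimizer.py` provides typically has this shape. The sketch below is a generic illustration of the pattern, not the project's actual code; the function name and signature are made up for the example:

```python
import math

def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr, min_lr=0.0):
    """Generic sketch: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warmup phase: ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay phase: progress goes from 0 (end of warmup) to 1 (last step).
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

This kind of helper is usually wired into the training loop via `torch.optim.lr_scheduler.LambdaLR` or by setting `param_group["lr"]` each step.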
Run `python main.py --run <mode> [options]`.
Modes (value for `--run`):

- `distill` - Phase I: distill RETFound teacher -> VMamba student
- `head` - Phase II: train classification head on distilled backbone
- `eval` - evaluate VMamba (expects a checkpoint saved by head training)
- `retfound_linear` - RETFound linear probe (uses `train_retfound.py`)
- `retfound_finetune` - RETFound fine-tuning
- `retfound_eval` - evaluate a RETFound Lightning checkpoint
Shared CLI arguments (most common):

- `--lr` - learning rate (default from `config/constants.py`)
- `--mask_ratio` - mask ratio used during distillation
- `--dist_epochs` - number of distillation epochs
- `--head_epochs` - number of head-training epochs
- `--teacher_ckpt` - optional teacher checkpoint override
- `--load_backbone` - path to distilled backbone (required for `--run head`)
- `--load_model` - path to full model checkpoint for evaluation (required for `--run eval` and `--run retfound_eval`)
- `--dataset` - `idrid` (default), `aptos`, or `mbrset`
RETFound-specific:

- `--checkpoint` - RETFound pretrained checkpoint (optional)
- `--epochs` - epochs for the RETFound linear/finetune modes
The CLI enforces required flags per mode (e.g., `--load_backbone` for `head`). See `main.py` for the full help text.
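Per-mode flag enforcement usually follows a standard `argparse` pattern. The sketch below is illustrative only (the `REQUIRED` mapping and function name are assumptions; see `main.py` for the real implementation), showing only a subset of the flags:

```python
import argparse

# Illustrative mapping of mode -> flags that mode requires.
REQUIRED = {
    "head": ["load_backbone"],
    "eval": ["load_model"],
    "retfound_eval": ["load_model"],
}

def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--run", required=True,
                   choices=["distill", "head", "eval", "retfound_linear",
                            "retfound_finetune", "retfound_eval"])
    p.add_argument("--load_backbone")
    p.add_argument("--load_model")
    args = p.parse_args(argv)
    # Fail fast with a usage message if a mode-specific flag is missing.
    for flag in REQUIRED.get(args.run, []):
        if getattr(args, flag) is None:
            p.error(f"--{flag} is required for --run {args.run}")
    return args
```

`parser.error` prints the usage string and exits with status 2, so a missing `--load_backbone` for `head` is caught before any model is built.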
Evaluation — VMamba (IDRiD):

```bash
python main.py --run eval --load_model CHECKPOINT_PATH --dataset idrid
```

Evaluation — VMamba (MBRSET):

```bash
python main.py --run eval --load_model CHECKPOINT_PATH --dataset mbrset
```

Evaluation — VMamba (APTOS):

```bash
python main.py --run eval --load_model CHECKPOINT_PATH --dataset aptos
```

Evaluation — RETFound Lightning ckpt:

```bash
python main.py --run retfound_eval --load_model checkpoints/retfound_finetuned.ckpt --dataset idrid
```

Distillation (Phase I):

```bash
python main.py --run distill --lr 1e-4 --mask_ratio 0.75
```

Head training (Phase II):

```bash
python main.py --run head --load_backbone checkpoints/vmamba_distilled_student.pth --lr 1e-4
```

RETFound linear / finetune:
```bash
# Linear probe
python main.py --run retfound_linear --dataset idrid --lr 3e-4

# Fine-tune
python main.py --run retfound_finetune --dataset aptos --lr 1e-5
```

- VMamba head output (`train/head.py`) saves a file expected by `eval` that contains two keys, `backbone` and `head` (each a state_dict).
- RETFound evaluation expects a PyTorch Lightning checkpoint compatible with `train/train_retfound.RETFoundTask`.
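The two-key VMamba checkpoint format can be illustrated with a minimal sketch. The `nn.Linear` modules below are stand-ins for the real backbone and head, and the file name is arbitrary; only the `{"backbone": ..., "head": ...}` structure is what `eval` expects:

```python
import torch

# Stand-ins for the distilled VMamba backbone and the classification head.
backbone = torch.nn.Linear(8, 4)
head = torch.nn.Linear(4, 2)

# Save in the two-key format: one state_dict per module.
torch.save(
    {"backbone": backbone.state_dict(), "head": head.state_dict()},
    "example_head_ckpt.pth",
)

# At eval time, restore each state_dict into a freshly built module.
ckpt = torch.load("example_head_ckpt.pth", map_location="cpu")
backbone.load_state_dict(ckpt["backbone"])
head.load_state_dict(ckpt["head"])
```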