
zhanglabtools/OTAD


OTAD: Optimal Transport-Induced Adversarial Defense

This repository provides the implementation of OTAD on CIFAR-10, demonstrating two defense variants:

  • OTAD-T: Solves the convex integral problem (CIP) exactly via LP + QCQP (MOSEK solver).
  • OTAD-T-NN: Approximates the CIP solution using a neural network (CIPNet).
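The sketch below shows the kind of constraints the exact variant works with. It assumes OTAD-T relies on the standard interpolation conditions for convex, L-smooth potentials; the paper's formulation may add slack or other terms. For sample points x_i, candidate transport-map values g_i (the potential's gradients) and potential values phi_i, a convex L-smooth potential through the data exists exactly when phi_i >= phi_j + <g_j, x_i - x_j> + ||g_i - g_j||^2 / (2L) holds for every pair. The function name `violates_smooth_convex_interpolation` is hypothetical and does not come from the repository.

```python
import numpy as np

def violates_smooth_convex_interpolation(x, g, phi, L, tol=1e-9):
    """Return the (i, j) pairs that break the L-smooth convex interpolation inequality.

    x:   (n, d) sample points
    g:   (n, d) candidate gradients (transport-map values at x)
    phi: (n,)   candidate potential values
    L:   assumed smoothness constant (Lipschitz constant of the map)
    """
    bad = []
    n = len(x)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            # lower bound on phi_i implied by the data at x_j
            rhs = phi[j] + g[j] @ (x[i] - x[j]) + np.sum((g[i] - g[j]) ** 2) / (2.0 * L)
            if phi[i] < rhs - tol:
                bad.append((i, j))
    return bad
```

For phi(x) = ||x||^2 / 2 with g_i = x_i, every inequality holds with equality at L = 1, so the function returns an empty list. Claiming a smaller constant such as L = 0.5 produces violations.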

Project Structure

OTAD/
├── models/
│   ├── layers.py            # Transformer encoder blocks
│   ├── vit.py               # ViT, ViT_feat, EmbedLayer, Attention
│   ├── models.py            # BasicBlock, CNNBlock, DMLResNet
│   └── cipnet.py            # CIPNet
├── solvers/
│   ├── mosek_potential.py    # LP solver (convex potential)
│   └── mosek_test.py         # QCQP solver (OT map)
├── BPDA.py                   # BPDA wrapper and PGD attack
├── prepare_data.py           # Extract OT data from ViT backbone
├── train_dml.py              # Train DML ResNet for neighbor retrieval
├── prepare_cip_data.py       # Generate CIP training data
├── train_cipnet.py           # Train CIPNet
├── eval_pgd.py               # Evaluate under PGD attack (Linf + L2)
└── eval_autoattack.py        # Evaluate under AutoAttack (Linf + L2)

Requirements

Usage

Step 1: Prepare OT Data

Extract embeddings and encoder outputs from the pretrained ViT backbone:

python prepare_data.py
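The output of this step is a set of paired points. This sketch assumes each image's ViT embedding serves as a source point and its encoder output as that point's image under the transport map. It stores such pairs as flattened float32 arrays in a .npz file. prepare_data.py may use another layout, file format, or token selection, and `save_ot_pairs` is a hypothetical helper.

```python
import numpy as np

def save_ot_pairs(path, embeddings, encoder_outputs):
    """Store (source, target) pairs; row i of each array belongs to the same image."""
    src = np.asarray(embeddings, dtype=np.float32)
    tgt = np.asarray(encoder_outputs, dtype=np.float32)
    if len(src) != len(tgt):
        raise ValueError("each embedding needs exactly one encoder output")
    # assumption: flatten token/channel axes so each sample is one point in R^d
    src = src.reshape(len(src), -1)
    tgt = tgt.reshape(len(tgt), -1)
    np.savez(path, source=src, target=tgt)
    return src.shape, tgt.shape
```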

Step 2: Train DML ResNet

Train the deep metric learning (DML) network used for neighbor retrieval:

python train_dml.py
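Neighbor retrieval in the learned feature space can be sketched as a plain k-nearest-neighbor search. This assumes Euclidean distance on features already produced by the DML ResNet. The repository may use a different metric or an indexing library, and `knn_indices` is a hypothetical name.

```python
import numpy as np

def knn_indices(train_feats, query_feats, k):
    """Indices of the k nearest training features (Euclidean) for each query, nearest first."""
    # squared distances via ||q||^2 - 2 q.t + ||t||^2, shape (n_query, n_train)
    d2 = (np.sum(query_feats ** 2, axis=1, keepdims=True)
          - 2.0 * query_feats @ train_feats.T
          + np.sum(train_feats ** 2, axis=1))
    # argpartition finds the k smallest per row without a full sort...
    idx = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
    # ...but leaves them unordered, so sort those k by distance
    order = np.take_along_axis(d2, idx, axis=1).argsort(axis=1)
    return np.take_along_axis(idx, order, axis=1)
```

For training features [[0, 0], [1, 0], [5, 5], [0, 2]] and the query [[0.1, 0]], `knn_indices(train, query, 2)` returns [[0, 1]].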

Step 3: Generate CIP Training Data

Solve the LP + QCQP for the training samples to create supervision data for CIPNet:

python prepare_cip_data.py
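The LP in solvers/mosek_potential.py is not reproduced here. To show what recovering a convex potential from map values involves, the sketch below swaps the LP for Rockafellar's shortest-path construction. For points x_i and gradients g_i, it finds values phi_i with phi_j >= phi_i + <g_i, x_j - x_i> for all pairs, or reports that none exist. It enforces plain convexity only and ignores any smoothness constraint, slack, or objective the repository's LP may include. `convex_potential_values` is hypothetical.

```python
import numpy as np

def convex_potential_values(x, g, tol=1e-10):
    """Values phi_i with phi_j >= phi_i + <g_i, x_j - x_i> for all i, j.

    Shortest-path distances on a complete graph whose edge j -> i has weight
    <g_i, x_i - x_j> satisfy exactly these inequalities. A negative cycle means
    the gradients are not cyclically monotone, so no convex potential fits them.
    """
    n = len(x)
    gx = np.sum(g * x, axis=1)          # <g_i, x_i>
    W = gx[None, :] - x @ g.T           # W[j, i] = <g_i, x_i - x_j>
    d = np.zeros(n)                     # virtual source joined to every node at cost 0
    for _ in range(n):                  # Bellman-Ford relaxation rounds
        d_new = np.min(d[:, None] + W, axis=0)
        if np.all(d_new >= d - tol):    # no distance decreased: converged
            return d
        d = d_new
    raise ValueError("gradients are not cyclically monotone; no convex potential fits them")
```

An LP formulation, as the repository uses, expresses the same inequalities as linear constraints and can additionally carry an objective, slack variables, or smoothness terms. The shortest-path view is shown only because it makes feasibility easy to see.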

Step 4: Train CIPNet (for OTAD-T-NN)

Train the neural network to approximate CIP solutions:

python train_cipnet.py
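CIPNet's architecture lives in models/cipnet.py and is not reproduced here. The sketch only illustrates the supervised idea behind this step, assuming CIPNet is trained by regression onto the solver outputs from Step 3. A linear model (a stand-in, not CIPNet) is fit to (input, target) pairs by full-batch gradient descent on the mean squared error. `fit_linear_surrogate` and the MSE loss are assumptions.

```python
import numpy as np

def fit_linear_surrogate(inputs, targets, lr=0.1, steps=500):
    """Fit targets ~ inputs @ W + b by full-batch gradient descent on the MSE."""
    d_in, d_out = inputs.shape[1], targets.shape[1]
    W = np.zeros((d_in, d_out))
    b = np.zeros(d_out)
    losses = []
    for _ in range(steps):
        resid = inputs @ W + b - targets
        losses.append(np.mean(resid ** 2))
        # gradient of the MSE averaged over all n * d_out entries
        grad_W = 2.0 * inputs.T @ resid / resid.size
        grad_b = 2.0 * resid.sum(axis=0) / resid.size
        W -= lr * grad_W
        b -= lr * grad_b
    return W, b, losses
```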

Step 5: Evaluate

BPDA + PGD Attack:

python eval_pgd.py --defense otad-t --gpu 0
python eval_pgd.py --defense otad-t-nn --gpu 0
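BPDA attacks a defense whose forward pass is not usefully differentiable by running the real defense forward and substituting a surrogate gradient backward. This toy sketch assumes the common identity surrogate. A 4-level quantizer stands in for the OTAD defense and a logistic classifier for the downstream model; neither is the code in BPDA.py. `quantize` and `bpda_pgd_linf` are hypothetical names, and eps = 8/255 with step size 2/255 are common CIFAR-10 L-inf settings, not values read from eval_pgd.py.

```python
import numpy as np

def quantize(x, levels=4):
    # stand-in defense: its true gradient is zero almost everywhere
    return np.round(x * levels) / levels

def bpda_pgd_linf(x, y, w, c, eps=8/255, alpha=2/255, steps=10):
    """L-inf PGD on logistic(w . quantize(x) + c), using BPDA with an identity backward pass."""
    x_adv = x.copy()
    for _ in range(steps):
        z = quantize(x_adv)                         # forward through the real defense
        p = 1.0 / (1.0 + np.exp(-(z @ w + c)))      # predicted probability of class 1
        grad_z = (p - y) * w                        # d(binary cross-entropy)/dz
        grad_x = grad_z                             # BPDA: treat d quantize/dx as identity
        x_adv = x_adv + alpha * np.sign(grad_x)     # ascend the loss
        x_adv = np.clip(x_adv, x - eps, x + eps)    # project back into the eps-ball
        x_adv = np.clip(x_adv, 0.0, 1.0)            # keep a valid pixel range
    return x_adv
```

The quantizer's true gradient is zero almost everywhere, so plain PGD would not move the input. With BPDA, every step for a label-1 input follows -sign(w), and the result settles on the boundary of the eps-ball.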

AutoAttack (OTAD-T-NN only):

python eval_autoattack.py --gpu 0

Pretrained Checkpoints

Place the model weights in ./checkpoints/. The ViT backbone is pretrained; the other two files are produced by the training steps above:

  • vit_cifar10.pth — ViT backbone
  • dml_resnet.pth — DML ResNet (generated by train_dml.py)
  • cipnet.pth — CIPNet (generated by train_cipnet.py)
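A small pre-flight check before running the evaluation scripts might look like this. The file names come from the list above. The assumptions that only `--defense otad-t-nn` needs cipnet.pth (OTAD-T solves the CIP exactly) and that both variants need the DML ResNet are ours, not the repository's. `missing_checkpoints` is a hypothetical helper.

```python
from pathlib import Path

# file names from the checkpoint list; the hints say where each file comes from
SOURCES = {
    "vit_cifar10.pth": "ViT backbone (pretrained)",
    "dml_resnet.pth": "DML ResNet (python train_dml.py)",
    "cipnet.pth": "CIPNet (python train_cipnet.py)",
}

def missing_checkpoints(ckpt_dir="./checkpoints", defense="otad-t-nn"):
    """List checkpoint files a given defense is assumed to need but that are absent."""
    names = ["vit_cifar10.pth", "dml_resnet.pth"]   # assumed: both variants need these
    if defense == "otad-t-nn":
        names.append("cipnet.pth")                  # assumed: only the NN variant uses CIPNet
    root = Path(ckpt_dir)
    return [f"{name}: {SOURCES[name]}" for name in names if not (root / name).is_file()]
```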

Citation

If you find this work useful, please cite:

@article{gai2026otad,
      title={OTAD: An Optimal Transport-Induced Robust Model for Agnostic Adversarial Attack}, 
      author={Kuo Gai and Sicong Wang and Shihua Zhang},
      year={2026},
      journal={arXiv preprint arXiv:2408.00329}
}
