This project compares Vision Transformers (ViT) and ResNet50 architectures for multi-label classification of chest X-ray images using the NIH Chest X-ray dataset. The study includes class balancing, focal loss, ensemble methods, uncertainty estimation, and model calibration techniques.
- Overview
- Dataset
- Architecture Details
- Key Features
- Results
- Setup and Installation
- Model Performance
- Visualization Examples
- Technical Details
This project implements and compares two state-of-the-art deep learning architectures for medical image classification:
- Vision Transformer (ViT): A pure transformer-based approach for image classification
- ResNet50: A convolutional neural network with residual connections
Both models are trained on the NIH Chest X-ray dataset to predict 15 labels (14 thoracic conditions plus "No Finding") from chest X-ray images.
- Source: NIH Chest X-ray dataset
- Images: 5,606 chest X-ray images (a subset of the full dataset's 112,120 images)
- Labels: 15 classes (14 conditions plus "No Finding"):
- No Finding, Atelectasis, Cardiomegaly, Consolidation, Edema
- Effusion, Emphysema, Fibrosis, Hernia, Infiltration
- Mass, Nodule, Pleural_Thickening, Pneumonia, Pneumothorax
- Preprocessing: Images resized to 224×224 pixels; grayscale converted to RGB
Vision Transformer (ViT):
- Patch Size: 16×16 pixels
- Embedding Dimension: 256
- Number of Heads: 4
- Transformer Blocks: 6
- MLP Units: [1024, 256]
- Dropout: 0.3
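The ViT hyperparameters above fix the token count and layer shapes. The helper below is hypothetical (not from the project code) and works through that arithmetic, assuming non-overlapping patches that are flattened and linearly projected to the embedding dimension as in Dosovitskiy et al.; a [CLS] token, if the model uses one, would add a 197th position.

```python
def vit_shapes(image_size=224, patch_size=16, channels=3, embed_dim=256, num_heads=4):
    """Derive token count and layer sizes implied by the ViT hyperparameters (hypothetical helper)."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    if embed_dim % num_heads:
        raise ValueError("embed_dim must be divisible by num_heads")
    per_side = image_size // patch_size               # patches along each axis
    num_patches = per_side * per_side                 # tokens fed to the transformer (no [CLS])
    patch_dim = patch_size * patch_size * channels    # raw values in one flattened patch
    proj_params = patch_dim * embed_dim + embed_dim   # linear patch projection: weights + bias
    head_dim = embed_dim // num_heads                 # width of each attention head
    return {"num_patches": num_patches, "patch_dim": patch_dim,
            "proj_params": proj_params, "head_dim": head_dim}

print(vit_shapes())
# {'num_patches': 196, 'patch_dim': 768, 'proj_params': 196864, 'head_dim': 64}
```

With 196 tokens and 4 heads, each head computes a 196 × 196 attention matrix in a 64-dimensional subspace.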
ResNet50:
- Base Model: ResNet50 pre-trained on ImageNet
- Input: 224×224×3 RGB images
- Global Average Pooling: Applied to the final convolutional feature maps
- Classification Head: Dense layer with sigmoid activation, one output per label
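Because the head uses sigmoid rather than softmax, each of the 15 outputs is an independent probability and one image can receive several labels. A minimal decoding sketch, assuming a fixed 0.5 threshold (the project's threshold is not stated) and that the model's output order matches the label list in the Dataset section; `decode` is a hypothetical helper.

```python
import math

# Assumed output order: the label list from the Dataset section.
LABELS = ["No Finding", "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
          "Effusion", "Emphysema", "Fibrosis", "Hernia", "Infiltration",
          "Mass", "Nodule", "Pleural_Thickening", "Pneumonia", "Pneumothorax"]

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def decode(logits, threshold=0.5):
    """Turn 15 head logits into independent probabilities and the labels at or above threshold."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    probs = [sigmoid(z) for z in logits]
    predicted = [name for name, p in zip(LABELS, probs) if p >= threshold]
    return probs, predicted

logits = [-3.0] * len(LABELS)
logits[LABELS.index("Effusion")] = 1.2
logits[LABELS.index("Infiltration")] = 0.4
probs, predicted = decode(logits)
print(predicted)  # ['Effusion', 'Infiltration']
```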
- Undersampling: Reduced "No Finding" class by 50%
- Class Weights: Computed balanced weights for each medical condition
- Focal Loss: Addressed class imbalance with α (class-balancing weight) and γ (focusing) parameters
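The two rebalancing steps can be illustrated on a toy label list. This is a sketch under assumptions: `undersample_no_finding` and `balanced_label_weights` are hypothetical helpers, "No Finding" images are taken to be those whose only label is "No Finding", and the weights apply scikit-learn's "balanced" rule, n_samples / (n_classes × count), to each label treated as its own binary problem.

```python
import random

def undersample_no_finding(samples, keep_fraction=0.5, seed=0):
    """Randomly keep `keep_fraction` of the images labelled only "No Finding".

    `samples` is a list of (image_id, set_of_labels); all other images are kept.
    """
    rng = random.Random(seed)
    normal = [s for s in samples if s[1] == {"No Finding"}]
    abnormal = [s for s in samples if s[1] != {"No Finding"}]
    kept_normal = rng.sample(normal, round(len(normal) * keep_fraction))
    return abnormal + kept_normal

def balanced_label_weights(samples, label_names):
    """Per-label (negative_weight, positive_weight) using n / (2 * count) for each binary label."""
    n = len(samples)
    weights = {}
    for name in label_names:
        n_pos = sum(1 for _, labels in samples if name in labels)
        n_neg = n - n_pos
        # A label with no positives (or no negatives) gets weight 0 for the missing side.
        w_pos = n / (2 * n_pos) if n_pos else 0.0
        w_neg = n / (2 * n_neg) if n_neg else 0.0
        weights[name] = (w_neg, w_pos)
    return weights

data = [(f"img{i}", {"No Finding"}) for i in range(4)] + [("e", {"Effusion"}), ("f", {"Effusion", "Mass"})]
balanced = undersample_no_finding(data)
print(len(balanced), balanced_label_weights(balanced, ["Effusion", "Mass"]))
```

Rarer labels receive larger positive weights, so their errors count for more in a weighted loss.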
- Mixed Precision: FP16 computation (variables kept in FP32) for faster training
- Cosine Learning Rate Decay: Smooth learning rate scheduling
- Early Stopping: Prevented overfitting with validation monitoring
- Data Augmentation: Rotation, shifting, zooming, and horizontal flipping
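Early stopping amounts to a patience counter over a validation metric. The class below is a hypothetical stand-in; the metric the project monitored and its patience are not stated. In Keras the same behavior is available through `tf.keras.callbacks.EarlyStopping(monitor=..., patience=..., mode=..., restore_best_weights=True)`.

```python
class EarlyStopping:
    """Signal a stop once a validation metric fails to improve for `patience` epochs.

    Use mode="max" for metrics such as AUROC or AUPRC and mode="min" for losses.
    """

    def __init__(self, patience=3, mode="max", min_delta=0.0):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.best = None        # best metric value seen so far
        self.best_epoch = None  # epoch whose weights would be restored
        self.wait = 0           # epochs since the last improvement

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def step(self, epoch, value):
        """Record one epoch's validation metric; return True when training should stop."""
        if self._improved(value):
            self.best, self.best_epoch, self.wait = value, epoch, 0
            return False
        self.wait += 1
        return self.wait >= self.patience
```

In a training loop, call `step(epoch, val_auroc)` after each epoch and break when it returns True, then reload the weights saved at `best_epoch`.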
- Multi-label Metrics: AUROC, AUPRC, and micro-F1 scores
- Per-class Analysis: Individual performance for each medical condition
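Micro-F1 pools true positives, false positives, and false negatives over every (image, label) pair before computing precision and recall, so frequent labels dominate it. A sketch assuming a single 0.5 threshold for all labels (the project's threshold is not stated). When no probability crosses the threshold there are no true positives and micro-F1 is exactly 0, which shows how a model can score 0.000 on this metric while its AUROC is not zero.

```python
def micro_f1(y_true, y_prob, threshold=0.5):
    """Micro-averaged F1 for multi-label data.

    y_true: list of 0/1 label vectors; y_prob: matching lists of sigmoid probabilities.
    """
    tp = fp = fn = 0
    for truth, probs in zip(y_true, y_prob):
        for t, p in zip(truth, probs):
            pred = 1 if p >= threshold else 0
            if pred == 1 and t == 1:
                tp += 1
            elif pred == 1 and t == 0:
                fp += 1
            elif pred == 0 and t == 1:
                fn += 1
    if tp == 0:
        return 0.0  # no true positives: precision or recall is zero (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

y_true = [[1, 0, 1], [0, 1, 0]]
print(micro_f1(y_true, [[0.9, 0.2, 0.4], [0.1, 0.7, 0.6]]))
print(micro_f1(y_true, [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]))
```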
- Ensemble Methods: Combined predictions from multiple models
- Uncertainty Estimation: Monte Carlo dropout for confidence assessment
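Monte Carlo dropout keeps dropout active at inference and treats the spread of repeated predictions as an uncertainty estimate; ensembling averages the probabilities of separately trained models. The framework-free sketch below uses a hypothetical one-layer scorer (`toy_logit`) in place of the real networks and borrows the 0.3 dropout rate from the ViT configuration. The project's number of stochastic passes and its ensemble weighting are not stated, so 50 passes and an unweighted mean are assumptions. In Keras, a stochastic pass corresponds to calling `model(x, training=True)`.

```python
import math
import random
import statistics

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def toy_logit(weights, bias, features, rng=None, rate=0.0):
    """Hypothetical one-layer scorer standing in for a real network.

    When `rng` is given and rate > 0, each input feature is dropped with probability
    `rate` and survivors are scaled by 1 / (1 - rate) (inverted dropout).
    """
    total = bias
    for w, x in zip(weights, features):
        if rng is not None and rate > 0.0:
            if rng.random() < rate:
                continue
            x = x / (1.0 - rate)
        total += w * x
    return total

def mc_dropout_predict(weights, bias, features, n_samples=50, rate=0.3, seed=0):
    """Run n_samples stochastic passes with dropout on; return (mean, std) of the probability."""
    rng = random.Random(seed)
    probs = [sigmoid(toy_logit(weights, bias, features, rng, rate)) for _ in range(n_samples)]
    return statistics.mean(probs), statistics.pstdev(probs)

def ensemble_average(prob_vectors):
    """Unweighted mean of per-label probability vectors produced by different models."""
    n = len(prob_vectors)
    return [sum(col) / n for col in zip(*prob_vectors)]

mean, std = mc_dropout_predict([1.0, -2.0, 0.5], 0.1, [0.8, 0.3, 1.2])
print(f"p = {mean:.3f} +/- {std:.3f}")
```

A large standard deviation flags predictions that should be treated with less confidence.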
- Temperature Scaling: Improved probability calibration
- Reliability Diagrams: Visualized calibration quality
- Binary Cross-Entropy: Used as the metric for quantifying calibration improvements
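Temperature scaling divides every logit by a single scalar T fitted on validation data; T > 1 softens overconfident probabilities without changing their ranking, so AUROC is unaffected. The sketch below fits T by grid search on binary cross-entropy and computes the per-bin statistics a reliability diagram plots (mean predicted probability against observed positive rate), plus an expected-calibration-error summary. The grid, the 10 bins, and a single T shared by all labels are assumptions; L-BFGS or gradient descent on T is the more common fitting method.

```python
import math

def _sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def bce(labels, logits, temperature=1.0):
    """Mean binary cross-entropy of sigmoid(logit / T) against 0/1 labels."""
    eps = 1e-12
    total = 0.0
    for y, z in zip(labels, logits):
        p = min(max(_sigmoid(z / temperature), eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)

def fit_temperature(labels, logits, grid=None):
    """Choose the temperature that minimizes validation BCE (simple grid search)."""
    if grid is None:
        grid = [0.5 + 0.05 * i for i in range(91)]  # T from 0.5 to about 5.0
    return min(grid, key=lambda t: bce(labels, logits, t))

def reliability_bins(labels, probs, n_bins=10):
    """(mean predicted probability, observed positive rate, count) per non-empty bin."""
    sums = [[0.0, 0.0, 0] for _ in range(n_bins)]
    for y, p in zip(labels, probs):
        b = min(int(p * n_bins), n_bins - 1)
        sums[b][0] += p
        sums[b][1] += y
        sums[b][2] += 1
    return [(s / c, pos / c, c) for s, pos, c in sums if c > 0]

def expected_calibration_error(labels, probs, n_bins=10):
    """Count-weighted mean gap between predicted probability and observed rate across bins."""
    n = len(labels)
    return sum(c / n * abs(rate - mean_p)
               for mean_p, rate, c in reliability_bins(labels, probs, n_bins))

labels = [1] * 7 + [0] * 3
logits = [4.0] * 10  # overconfident: predicts about 0.98 where only 70% are positive
t = fit_temperature(labels, logits)
print(t, bce(labels, logits, 1.0), bce(labels, logits, t))
```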
| Model | AUROC | AUPRC | Micro-F1 |
|---|---|---|---|
| ViT + Class Weights | 0.480 | 0.130 | 0.145 |
| ViT Baseline | 0.518 | 0.129 | 0.000 |
| Model | AUROC | AUPRC | Micro-F1 |
|---|---|---|---|
| ResNet + Class Weights | 0.714 | 0.196 | 0.365 |
| ResNet Baseline | 0.712 | 0.194 | 0.231 |
| Ensemble | 0.723 | 0.210 | 0.322 |
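- ResNet50 outperformed ViT on every metric; ViT stayed near chance (AUROC 0.480 to 0.518), while ResNet50 reached 0.712 to 0.714
- Class weighting mainly improved ResNet50's micro-F1 (from 0.231 to 0.365); AUROC and AUPRC changed only marginally
- The ensemble achieved the highest AUROC (0.723) and AUPRC (0.210), although ResNet + Class Weights had the higher micro-F1 (0.365 vs. 0.322)
- Temperature scaling improved model calibration, as measured by binary cross-entropy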
- Python 3.7+
- TensorFlow 2.x
- Google Colab (recommended for GPU access)
```bash
# Install required packages
pip install tensorflow tensorflow-addons scikit-learn
pip install umap-learn matplotlib opencv-python
pip install transformers pandas numpy
```

| Condition | AUROC | AUPRC |
|---|---|---|
| No Finding | 0.480 | 0.360 |
| Atelectasis | 0.541 | 0.119 |
| Cardiomegaly | 0.575 | 0.087 |
| Consolidation | 0.596 | 0.055 |
| Edema | 0.691 | 0.238 |
| Effusion | 0.440 | 0.143 |
| Emphysema | 0.464 | 0.052 |
| Fibrosis | 0.473 | 0.005 |
| Hernia | NaN | 0.000 |
| Infiltration | 0.433 | 0.269 |
| Mass | 0.686 | 0.095 |
| Nodule | 0.624 | 0.224 |
| Pleural_Thickening | 0.549 | 0.075 |
| Pneumonia | 0.609 | 0.028 |
| Pneumothorax | 0.572 | 0.198 |
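AUROC is undefined when the evaluation split contains only one class for a label (scikit-learn's `roc_auc_score` raises a `ValueError` in that case), so a NaN entry such as Hernia's is consistent with a split that has no positive Hernia cases. The rank-based sketch below (the Mann-Whitney formulation) returns NaN in that situation instead of raising; it is illustrative and not the project's evaluation code.

```python
import math

def auroc(labels, scores):
    """Rank-based AUROC (Mann-Whitney U); NaN if the labels contain only one class."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        return math.nan  # undefined: no positive (or no negative) examples to rank
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        # Find the run of tied scores starting at position i and give each the average rank.
        j = i
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    rank_sum_pos = sum(r for r, y in zip(ranks, labels) if y == 1)
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)

def per_class_auroc(y_true, y_score, names):
    """Apply auroc() column by column to multi-label data given as lists of rows."""
    return {name: auroc([row[c] for row in y_true], [row[c] for row in y_score])
            for c, name in enumerate(names)}

print(per_class_auroc([[1, 0], [0, 0]], [[0.9, 0.3], [0.2, 0.6]], ["Effusion", "Hernia"]))
```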
The project includes several visualization techniques:
- Loss, AUROC, and AUPRC progression over epochs
- Comparison between weighted and baseline models
- Vision Transformer attention rollout visualization
- Grad-CAM saliency maps for ResNet50
- Correct and incorrect predictions with ground truth labels
- Visual comparison between different models
- Reliability diagrams showing probability calibration
- Before and after temperature scaling comparison
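Attention rollout (Abnar and Zuidema, 2020) estimates how much each input patch contributes to the final tokens: average attention over heads, add the identity to account for residual connections, re-normalize rows, and multiply the per-block matrices from first to last. A plain-Python sketch under stated assumptions: attention weights arrive as nested lists shaped [block][head][query][key], and because it is not stated whether this project's ViT uses a [CLS] token, the hypothetical `saliency_grid` supports both reading the [CLS] row and averaging over all query tokens. For the configuration above, `side` would be 14 (196 patches).

```python
def _matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def _normalize_rows(m):
    out = []
    for row in m:
        s = sum(row)
        out.append([v / s for v in row])
    return out

def attention_rollout(layers):
    """Combine per-block, per-head attention matrices into one token-to-token relevance matrix."""
    rollout = None
    for heads in layers:
        n = len(heads[0])
        avg = [[sum(h[i][j] for h in heads) / len(heads) for j in range(n)] for i in range(n)]
        # Identity term models the residual connection around each attention block.
        aug = _normalize_rows([[avg[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
                               for i in range(n)])
        rollout = aug if rollout is None else _matmul(aug, rollout)
    return rollout

def saliency_grid(rollout, side, has_cls_token=False):
    """Reshape per-patch relevance into a side x side grid for overlaying on the image."""
    if has_cls_token:
        scores = rollout[0][1:]  # how much the [CLS] token draws from each patch
    else:
        n = len(rollout)
        scores = [sum(rollout[i][j] for i in range(n)) / n for j in range(n)]  # mean over queries
    if len(scores) != side * side:
        raise ValueError("number of patch tokens must equal side * side")
    return [scores[r * side:(r + 1) * side] for r in range(side)]

uniform = [[[0.5, 0.5], [0.5, 0.5]]]  # one block, one head, two tokens
print(attention_rollout([uniform, uniform]))
```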
- Focal Loss:
  $$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
- Weighted BCE:
  $$\mathrm{WBCE} = -w\left[y \log(p) + (1 - y) \log(1 - p)\right]$$
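Both losses act per label and are then averaged over the batch. In the focal loss, p_t is the predicted probability of the true class (p_t = p if y = 1, else 1 − p), and α_t is α for positives and 1 − α for negatives. The sketch below assumes α = 0.25 and γ = 2, the defaults from Lin et al.; the values used in this project are not stated. With γ = 0 the focal loss reduces to α_t-weighted BCE, which the test below checks.

```python
import math

EPS = 1e-7  # clip probabilities to avoid log(0)

def focal_loss(y, p, alpha=0.25, gamma=2.0):
    """Binary focal loss for one label: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p = min(max(p, EPS), 1.0 - EPS)
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def weighted_bce(y, p, w=1.0):
    """Weighted binary cross-entropy for one label: -w * [y log p + (1 - y) log(1 - p)]."""
    p = min(max(p, EPS), 1.0 - EPS)
    return -w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))

def mean_loss(loss_fn, y_true, y_prob, **kwargs):
    """Average a per-label loss over every (image, label) pair of a multi-label batch."""
    values = [loss_fn(t, p, **kwargs)
              for row_t, row_p in zip(y_true, y_prob)
              for t, p in zip(row_t, row_p)]
    return sum(values) / len(values)

# An easy, correct prediction (p = 0.9 for a positive) is down-weighted by (1 - 0.9)**2 = 0.01.
print(focal_loss(1, 0.9), weighted_bce(1, 0.9))
```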
- Optimizer: Adam with cosine decay
- Learning Rate: 3e-5 initial, 1e-6 final
- Batch Size: 8
- Epochs: 10-20 (with early stopping)
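The cosine schedule from 3e-5 down to 1e-6 has a closed form. A minimal sketch, assuming the decay spans all training steps; the project's `total_steps` is not stated (with batch size 8 it would be about ceil(n_train / 8) × epochs). If the project used Keras, `tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=3e-5, decay_steps=total_steps, alpha=1e-6 / 3e-5)` follows the same curve, since `alpha` is the final learning rate expressed as a fraction of the initial one.

```python
import math

def cosine_decay_lr(step, total_steps, lr_init=3e-5, lr_final=1e-6):
    """Learning rate at `step`: cosine curve from lr_init (step 0) to lr_final (step total_steps)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    progress = min(step, total_steps) / total_steps  # clamp: hold lr_final after the schedule ends
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_final + (lr_init - lr_final) * cosine

for step in (0, 250, 500, 750, 1000):
    print(step, f"{cosine_decay_lr(step, 1000):.2e}")
```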
- Rotation: ±15 degrees
- Translation: ±10% width/height
- Zoom: ±10%
- Horizontal flip: 50% probability
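The augmentation ranges above can be sampled and turned into one affine transform per image. A sketch under assumptions: the helper names are hypothetical, zoom is drawn from [0.9, 1.1], and interpolation and border filling are left out. If the project used Keras' `ImageDataGenerator` (not stated), the matching arguments would be `rotation_range=15`, `width_shift_range=0.1`, `height_shift_range=0.1`, `zoom_range=0.1`, and `horizontal_flip=True`.

```python
import math
import random

def sample_augmentation(rng, rotation_deg=15.0, shift_frac=0.10, zoom_frac=0.10, flip_prob=0.5):
    """Draw one set of augmentation parameters within the listed ranges."""
    return {
        "angle": math.radians(rng.uniform(-rotation_deg, rotation_deg)),
        "tx": rng.uniform(-shift_frac, shift_frac),   # fraction of image width
        "ty": rng.uniform(-shift_frac, shift_frac),   # fraction of image height
        "zoom": rng.uniform(1 - zoom_frac, 1 + zoom_frac),
        "flip": rng.random() < flip_prob,
    }

def augmentation_matrix(params, width=224, height=224):
    """2x3 affine matrix mapping each output pixel (x, y) back to the source pixel it samples.

    Combines a horizontal flip, rotation, zoom, and shift about the image centre.
    """
    cx, cy = (width - 1) / 2.0, (height - 1) / 2.0
    c, s = math.cos(params["angle"]), math.sin(params["angle"])
    z = params["zoom"]
    fx = -1.0 if params["flip"] else 1.0
    # Linear part: rotation applied to x-flipped coordinates, divided by zoom
    # (zoom > 1 samples a smaller source region, i.e. zooms in).
    m00, m01 = c * fx / z, -s / z
    m10, m11 = s * fx / z, c / z
    # Translation keeps the centre fixed, then adds the shift in pixels.
    t0 = cx + params["tx"] * width - (m00 * cx + m01 * cy)
    t1 = cy + params["ty"] * height - (m10 * cx + m11 * cy)
    return [[m00, m01, t0], [m10, m11, t1]]

def apply_affine(m, x, y):
    """Source coordinates for output pixel (x, y)."""
    return m[0][0] * x + m[0][1] * y + m[0][2], m[1][0] * x + m[1][1] * y + m[1][2]

params = sample_augmentation(random.Random(42))
print(params, apply_affine(augmentation_matrix(params), 0, 0))
```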
- Dosovitskiy, A., et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR 2021.
- He, K., et al. "Deep Residual Learning for Image Recognition." CVPR 2016.
- Lin, T. Y., et al. "Focal Loss for Dense Object Detection." ICCV 2017.
- Wang, X., et al. "ChestX-ray8: Hospital-scale Chest X-ray Database and Benchmarks on Weakly-Supervised Classification and Localization of Common Thorax Diseases." CVPR 2017.
Contributions are welcome! Please feel free to submit a Pull Request.