RushyanthN/ResNet50_VisionTransformers_for_ChestXray_Classification

Vision Transformers vs ResNet50 for Chest X-ray Classification

This project compares Vision Transformers (ViT) and ResNet50 architectures for multi-label classification of chest X-ray images using the NIH Chest X-ray dataset. The study includes class balancing, focal loss, ensemble methods, uncertainty estimation, and model calibration techniques.

🔍 Overview

This project implements and compares two state-of-the-art deep learning architectures for medical image classification:

  1. Vision Transformer (ViT): A pure transformer-based approach for image classification
  2. ResNet50: A convolutional neural network with residual connections

Both models are trained on the NIH Chest X-ray dataset to predict 15 labels per chest X-ray image: 14 thoracic conditions plus "No Finding".

📊 Dataset

  • Source: NIH Chest X-ray dataset
  • Images: 5,606 chest X-ray images
  • Labels: 15 classes (14 medical conditions plus "No Finding"):
    • No Finding, Atelectasis, Cardiomegaly, Consolidation, Edema
    • Effusion, Emphysema, Fibrosis, Hernia, Infiltration
    • Mass, Nodule, Pleural_Thickening, Pneumonia, Pneumothorax
  • Preprocessing: Images resized to 224×224 pixels, grayscale converted to RGB
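
Because one image can carry several findings, each label set has to become a fixed-length multi-hot vector before training. The sketch below assumes labels arrive as a single `|`-separated string per image (the convention used in the NIH metadata file); the helper name `encode_labels` and the label ordering are illustrative, not taken from the repository.

```python
from typing import List

# The 15 label names listed above, in a fixed (assumed) order.
LABELS = [
    "No Finding", "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
    "Effusion", "Emphysema", "Fibrosis", "Hernia", "Infiltration",
    "Mass", "Nodule", "Pleural_Thickening", "Pneumonia", "Pneumothorax",
]
LABEL_INDEX = {name: i for i, name in enumerate(LABELS)}


def encode_labels(finding_labels: str) -> List[int]:
    """Turn a '|'-separated label string into a 15-dimensional multi-hot vector."""
    vector = [0] * len(LABELS)
    for name in finding_labels.split("|"):
        vector[LABEL_INDEX[name.strip()]] = 1
    return vector


# Example: an image annotated with two findings.
example = encode_labels("Cardiomegaly|Emphysema")
```

Here `example` has ones at the Cardiomegaly and Emphysema positions and zeros elsewhere, which is the target format a sigmoid output layer expects.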

🏗️ Architecture Details

Vision Transformer (ViT)

  • Patch Size: 16×16 pixels
  • Embedding Dimension: 256
  • Number of Heads: 4
  • Transformer Blocks: 6
  • MLP Units: [1024, 256]
  • Dropout: 0.3
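
As a quick sanity check on these hyperparameters, the following arithmetic derives the sequence length and per-head width they imply. It assumes the 224×224 RGB input described in the Dataset section and non-overlapping patches; the variable names are illustrative.

```python
# Values taken from the configuration listed above.
image_size = 224
patch_size = 16
channels = 3          # grayscale images are converted to RGB
embed_dim = 256
num_heads = 4

patches_per_side = image_size // patch_size          # patches along one edge
num_patches = patches_per_side ** 2                  # tokens fed to the transformer
patch_values = patch_size * patch_size * channels    # raw values per flattened patch
head_dim = embed_dim // num_heads                    # width of each attention head

print(num_patches, patch_values, head_dim)
```

Each image therefore becomes 196 patch tokens, each flattened patch holds 768 raw values before projection to the 256-dimensional embedding, and every attention head works on 64 dimensions.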

ResNet50

  • Base Model: Pre-trained ResNet50 on ImageNet
  • Input: 224×224×3 RGB images
  • Global Average Pooling: Applied to feature maps
  • Classification Head: Dense layer with sigmoid activation

✨ Key Features

1. Class Balancing

  • Undersampling: Reduced "No Finding" class by 50%
  • Class Weights: Computed balanced weights for each medical condition
  • Focal Loss: Addressed class imbalance with α and γ parameters
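
The first two balancing steps can be sketched in plain Python. This is a minimal illustration, not the repository's code: the function names, the `(image_id, labels)` sample layout, and the per-label weight formula `n / (2 * positives)` (a common "balanced" heuristic) are all assumptions.

```python
import random


def undersample_no_finding(samples, keep_fraction=0.5, seed=0):
    """Keep every sample with a finding and a random fraction of 'No Finding' ones.

    samples: list of (image_id, label_list) pairs (hypothetical layout).
    """
    rng = random.Random(seed)
    no_finding = [s for s in samples if s[1] == ["No Finding"]]
    others = [s for s in samples if s[1] != ["No Finding"]]
    kept = rng.sample(no_finding, int(len(no_finding) * keep_fraction))
    return others + kept


def positive_class_weights(multi_hot_rows):
    """Per-label weight n / (2 * positives); rarer labels get larger weights."""
    n = len(multi_hot_rows)
    num_labels = len(multi_hot_rows[0])
    weights = []
    for j in range(num_labels):
        positives = sum(row[j] for row in multi_hot_rows)
        weights.append(n / (2 * positives) if positives else 0.0)
    return weights


# Toy data: 10 'No Finding' images and 4 images with a finding.
toy = [(f"img{i}", ["No Finding"]) for i in range(10)]
toy += [(f"img{i}", ["Edema"]) for i in range(10, 14)]
balanced = undersample_no_finding(toy)
weights = positive_class_weights([[1, 0], [1, 0], [0, 1], [1, 0]])
```

A fixed seed keeps the undersampled subset reproducible across runs, which matters when comparing weighted and baseline models on the same data.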

2. Advanced Training Techniques

  • Mixed Precision: FP16 training for faster computation
  • Cosine Learning Rate Decay: Smooth learning rate scheduling
  • Early Stopping: Prevented overfitting with validation monitoring
  • Data Augmentation: Rotation, shifting, zooming, and horizontal flipping
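
Early stopping on a validation metric follows a simple patience rule, sketched below. The patience value and the choice of validation loss as the monitored quantity are assumptions; the source only states that validation monitoring was used.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Validation loss improves twice, then stalls for three epochs.
stop_at = early_stopping_epoch([1.0, 0.9, 0.95, 0.92, 0.91, 0.8], patience=3)
```

In this toy run training halts at epoch index 4, before the later improvement at index 5 is seen; a larger patience trades extra epochs for a chance to catch such late gains.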

3. Model Evaluation

  • Multi-label Metrics: AUROC, AUPRC, and micro-F1 scores
  • Per-class Analysis: Individual performance for each medical condition
  • Ensemble Methods: Combined predictions from multiple models
  • Uncertainty Estimation: Monte Carlo dropout for confidence assessment
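
Micro-F1 pools true positives, false positives, and false negatives over all labels and samples before computing a single score. The sketch below is a framework-free illustration; the 0.5 decision threshold is an assumption, since the source does not state which threshold was used.

```python
def micro_f1(y_true, y_prob, threshold=0.5):
    """Micro-averaged F1 for multi-label predictions given as nested lists."""
    tp = fp = fn = 0
    for true_row, prob_row in zip(y_true, y_prob):
        for t, p in zip(true_row, prob_row):
            pred = 1 if p >= threshold else 0
            if pred and t:
                tp += 1
            elif pred and not t:
                fp += 1
            elif t and not pred:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


y_true = [[1, 0, 1], [0, 1, 0]]
y_prob = [[0.9, 0.2, 0.4], [0.6, 0.7, 0.1]]
score = micro_f1(y_true, y_prob)
```

Note that the score is exactly 0 whenever no positive prediction is correct, for example when every probability stays below the threshold, so micro-F1 is sensitive to threshold choice in a way AUROC and AUPRC are not.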

4. Model Calibration

  • Temperature Scaling: Improved probability calibration
  • Reliability Diagrams: Visualized calibration quality
  • Binary Cross-Entropy: Used to quantify calibration improvements
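
Temperature scaling divides each logit by a scalar T before the sigmoid and picks the T that minimizes binary cross-entropy on held-out data. The sketch below uses a simple grid search as a stand-in; the search range, function names, and toy logits are assumptions, and the source does not say how T was fitted.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce(labels, logits, temperature=1.0):
    """Mean binary cross-entropy of sigmoid(logit / temperature)."""
    eps = 1e-7
    total = 0.0
    for y, z in zip(labels, logits):
        p = min(max(sigmoid(z / temperature), eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)


def fit_temperature(labels, logits, candidates=None):
    """Pick the candidate temperature with the lowest held-out BCE (grid search)."""
    if candidates is None:
        candidates = [0.5 + 0.1 * i for i in range(46)]  # roughly 0.5 to 5.0
    return min(candidates, key=lambda t: bce(labels, logits, t))


# Overconfident toy model: two of six predictions are confidently wrong.
labels = [1, 1, 0, 0, 1, 0]
logits = [5.0, 5.0, -5.0, -5.0, -5.0, 5.0]
best_t = fit_temperature(labels, logits)
```

For an overconfident model the fitted temperature comes out above 1, which softens the probabilities and lowers the cross-entropy without changing the ranking of predictions (so AUROC is unaffected).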

📈 Results

Vision Transformer Performance

| Model | AUROC | AUPRC | Micro-F1 |
| --- | --- | --- | --- |
| ViT + Class Weights | 0.480 | 0.130 | 0.145 |
| ViT Baseline | 0.518 | 0.129 | 0.000 |

ResNet50 Performance

| Model | AUROC | AUPRC | Micro-F1 |
| --- | --- | --- | --- |
| ResNet + Class Weights | 0.714 | 0.196 | 0.365 |
| ResNet Baseline | 0.712 | 0.194 | 0.231 |
| Ensemble | 0.723 | 0.210 | 0.322 |
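
One common way to build such an ensemble is to average the per-label probabilities of the member models. The sketch below assumes an unweighted mean; the source does not state which models were combined or how, so treat the function and the toy values as illustrative only.

```python
def ensemble_average(model_probs):
    """Average per-label probabilities across models.

    model_probs: one entry per model, each shaped [samples][labels].
    """
    num_models = len(model_probs)
    num_samples = len(model_probs[0])
    num_labels = len(model_probs[0][0])
    return [
        [sum(m[i][j] for m in model_probs) / num_models for j in range(num_labels)]
        for i in range(num_samples)
    ]


# Toy predictions from two hypothetical models for one image and two labels.
model_a = [[0.8, 0.2]]
model_b = [[0.4, 0.6]]
combined = ensemble_average([model_a, model_b])
```

Averaging tends to help ranking metrics such as AUROC and AUPRC, while threshold-based metrics such as micro-F1 may need the decision threshold to be retuned for the averaged probabilities.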

Key Findings

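  • ResNet50 outperformed ViT on every reported metric (AUROC 0.712-0.714 vs. 0.480-0.518); both ViT variants scored near chance level on AUROC
  • Class weighting raised ResNet50 micro-F1 from 0.231 to 0.365 but left AUROC (0.712 to 0.714) and AUPRC (0.194 to 0.196) nearly unchanged
  • The ensemble achieved the highest AUROC (0.723) and AUPRC (0.210), while ResNet + Class Weights had the highest micro-F1 (0.365)
  • Temperature scaling improved model calibration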

🚀 Setup and Installation

Prerequisites

  • Python 3.7+
  • TensorFlow 2.x
  • Google Colab (recommended for GPU access)

Installation

# Install required packages
pip install tensorflow tensorflow-addons scikit-learn
pip install umap-learn matplotlib opencv-python
pip install transformers pandas numpy

📊 Model Performance

Per-Class Performance (ResNet50 + Weights)

| Condition | AUROC | AUPRC |
| --- | --- | --- |
| No Finding | 0.480 | 0.360 |
| Atelectasis | 0.541 | 0.119 |
| Cardiomegaly | 0.575 | 0.087 |
| Consolidation | 0.596 | 0.055 |
| Edema | 0.691 | 0.238 |
| Effusion | 0.440 | 0.143 |
| Emphysema | 0.464 | 0.052 |
| Fibrosis | 0.473 | 0.005 |
| Hernia | NaN | 0.000 |
| Infiltration | 0.433 | 0.269 |
| Mass | 0.686 | 0.095 |
| Nodule | 0.624 | 0.224 |
| Pleural_Thickening | 0.549 | 0.075 |
| Pneumonia | 0.609 | 0.028 |
| Pneumothorax | 0.572 | 0.198 |
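
Per-class AUROC is the probability that a randomly chosen positive example is scored above a randomly chosen negative one, so it is undefined when a class has only positives or only negatives in the evaluation set. The sketch below shows this with a pairwise implementation; it is an illustration rather than the repository's evaluation code. A NaN such as the Hernia entry would be consistent with the evaluation split containing no positive Hernia examples, though the source does not confirm this.

```python
import math


def auroc(labels, scores):
    """Pairwise AUROC: fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a win. Returns NaN when either group is empty.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        return float("nan")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


well_defined = auroc([1, 0, 1, 0], [0.9, 0.1, 0.4, 0.6])
undefined = auroc([0, 0, 0], [0.2, 0.7, 0.1])
```

The pairwise loop is quadratic in the number of samples, which is fine for illustration; library implementations use a rank-based computation that gives the same value.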

🖼️ Visualization Examples

The project includes several visualization techniques:

1. Training Curves

  • Loss, AUROC, and AUPRC progression over epochs
  • Comparison between weighted and baseline models

2. Attention Maps

  • Vision Transformer attention rollout visualization
  • Grad-CAM saliency maps for ResNet50

3. Prediction Examples

  • Correct and incorrect predictions with ground truth labels
  • Visual comparison between different models

4. Calibration Analysis

  • Reliability diagrams showing probability calibration
  • Before and after temperature scaling comparison
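
A reliability diagram plots, for each confidence bin, the mean predicted probability against the observed fraction of positives. The sketch below computes those points without any plotting library; the bin count and the function name are assumptions.

```python
def reliability_bins(labels, probs, num_bins=10):
    """Return (mean_confidence, fraction_positive, count) for each non-empty bin."""
    counts = [0] * num_bins
    prob_sums = [0.0] * num_bins
    positive_counts = [0] * num_bins
    for y, p in zip(labels, probs):
        idx = min(int(p * num_bins), num_bins - 1)  # p == 1.0 falls in the last bin
        counts[idx] += 1
        prob_sums[idx] += p
        positive_counts[idx] += y
    points = []
    for i in range(num_bins):
        if counts[i]:
            points.append(
                (prob_sums[i] / counts[i], positive_counts[i] / counts[i], counts[i])
            )
    return points


points = reliability_bins([1, 0, 1, 1], [0.95, 0.05, 0.55, 0.55], num_bins=2)
```

A perfectly calibrated model yields points on the diagonal (mean confidence equals fraction of positives); running the function on probabilities before and after temperature scaling gives the two curves to compare.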

🔧 Technical Details

Loss Functions

  • Focal Loss: FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
  • Weighted BCE: WBCE = -w[y*log(p) + (1-y)*log(1-p)]
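
Both formulas above can be written out for a single label as follows. This is a framework-free sketch: the defaults α = 0.25 and γ = 2 are the values recommended in the focal loss paper, not values confirmed by the source, and the function names are illustrative.

```python
import math

EPS = 1e-7  # keeps log() away from zero


def focal_loss(y, p, alpha=0.25, gamma=2.0):
    """Binary focal loss for one label: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    p = min(max(p, EPS), 1.0 - EPS)
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-dependent balancing factor
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def weighted_bce(y, p, w=1.0):
    """Weighted binary cross-entropy: -w * [y*log(p) + (1-y)*log(1-p)]."""
    p = min(max(p, EPS), 1.0 - EPS)
    return -w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))


easy = focal_loss(1, 0.9)   # confident and correct: strongly down-weighted
hard = focal_loss(1, 0.1)   # confident and wrong: keeps most of its loss
```

The `(1 - p_t)^γ` factor is what distinguishes the two losses: with γ = 0 focal loss reduces to an α-weighted cross-entropy, while larger γ shrinks the contribution of examples the model already classifies well.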

Optimization

  • Optimizer: Adam with cosine decay
  • Learning Rate: 3e-5 initial, 1e-6 final
  • Batch Size: 8
  • Epochs: 10-20 (with early stopping)
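
The learning-rate schedule implied by these settings can be sketched as a cosine curve from 3e-5 down to 1e-6. The function below is an illustration under that assumption; the step counts are hypothetical, and the source does not show the exact scheduler configuration.

```python
import math


def cosine_decay_lr(step, total_steps, initial_lr=3e-5, final_lr=1e-6):
    """Cosine-annealed learning rate that ends at final_lr after total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cosine


total = 1000  # hypothetical number of optimizer steps
start_lr = cosine_decay_lr(0, total)
mid_lr = cosine_decay_lr(total // 2, total)
end_lr = cosine_decay_lr(total, total)
```

The rate starts at 3e-5, passes through 1.55e-5 at the halfway point, and settles at 1e-6; in Keras the same floor can be expressed through the `alpha` argument of `tf.keras.optimizers.schedules.CosineDecay`, which is given as a fraction of the initial rate.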

Data Augmentation

  • Rotation: ±15 degrees
  • Translation: ±10% width/height
  • Zoom: ±10%
  • Horizontal flip: 50% probability
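
To make the ranges concrete, the sketch below draws one set of augmentation parameters per image. It only samples the parameters and does not transform pixels; the dictionary keys are hypothetical, and reading "Zoom: ±10%" as a scale factor between 0.9 and 1.1 is an assumption.

```python
import random


def sample_augmentation(rng):
    """Draw one random set of augmentation parameters within the listed ranges."""
    return {
        "rotation_deg": rng.uniform(-15.0, 15.0),
        "shift_x": rng.uniform(-0.10, 0.10),   # fraction of image width
        "shift_y": rng.uniform(-0.10, 0.10),   # fraction of image height
        "zoom": rng.uniform(0.90, 1.10),       # scale factor (assumed meaning)
        "horizontal_flip": rng.random() < 0.5,
    }


rng = random.Random(0)  # fixed seed for reproducible draws
draws = [sample_augmentation(rng) for _ in range(100)]
```

Sampling parameters separately from applying them makes it easy to log or replay the exact transformations used for a given training batch.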

📚 References

  1. Dosovitskiy, A., et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR 2021.
  2. He, K., et al. "Deep Residual Learning for Image Recognition." CVPR 2016.
  3. Lin, T. Y., et al. "Focal Loss for Dense Object Detection." ICCV 2017.
  4. Wang, X., et al. "ChestX-ray8: Hospital-scale Chest X-ray Database and Benchmarks on Weakly-Supervised Classification and Localization of Common Thorax Diseases." CVPR 2017.

🀝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
