lingchm/MMFL


Multi-Modal Fission Learning (MMFL)

Overview

MMFL (Multi-Modal Fission Learning) is a machine learning framework designed to handle multi-modal datasets with missing modalities. The framework employs a structural decomposition approach, separating multi-modal data into globally joint, locally joint, and individual components using supervision from labels. This enables effective learning even when some modalities are missing for certain samples, without requiring data imputation.
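The fission structure can be illustrated with a small synthetic example. This is a conceptual sketch only: the factor names, dimensions, and mixing model below are illustrative and are not the package's API.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p1, p2, p3 = 100, 20, 15, 10

# Latent factors: one globally joint (shared by all modalities),
# one locally joint (shared by modalities 1 and 2 only), and one
# individual factor per modality.
g = rng.standard_normal((n, 1))            # globally joint
l12 = rng.standard_normal((n, 1))          # locally joint (modalities 1, 2)
i1, i2, i3 = (rng.standard_normal((n, 1)) for _ in range(3))

def modality(factors, p):
    """Mix latent factors through random loadings, plus small noise."""
    X = sum(f @ rng.standard_normal((1, p)) for f in factors)
    return X + 0.1 * rng.standard_normal((n, p))

X1 = modality([g, l12, i1], p1)  # global + local + individual signal
X2 = modality([g, l12, i2], p2)
X3 = modality([g, i3], p3)       # no locally joint component

# A label driven mainly by the joint components, as in supervised fission
y = (g[:, 0] + 0.5 * l12[:, 0] > 0).astype(int)
print(X1.shape, X2.shape, X3.shape)
```

Because the label depends on the joint factors rather than the raw features, a model that recovers the fission structure can predict `y` even when one modality is unobserved for a sample.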

Key Features

  • Missing data robustness: Gracefully handles incomplete modalities without requiring imputation
  • Structural decomposition: Decomposes data into joint (global/local) and individual components using structured factorization
  • Supervised learning: Jointly learns prediction model and latent components
  • Automatic rank selection: Built-in rank selection heuristics
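MMFL's own heuristics live in `utils/rank_selection.py` and can also score ranks by predictive AUC. As a generic illustration of the idea (not the package's actual criterion), one can pick the smallest rank whose leading singular values explain a target share of variance:

```python
import numpy as np

def select_rank(X, var_threshold=0.9):
    """Smallest rank whose leading singular values explain at least
    `var_threshold` of the total variance of X."""
    s = np.linalg.svd(X, compute_uv=False)
    explained = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(explained, var_threshold) + 1)

rng = np.random.default_rng(0)
# Rank-3 signal plus small noise: the selected rank should be at most 3.
X = rng.standard_normal((50, 3)) @ rng.standard_normal((3, 30))
X += 0.01 * rng.standard_normal((50, 30))
print(select_rank(X))
```

Criteria based on held-out prediction performance (as in `rank_selection_criterion="auc"` below) tend to be preferable for supervised models, since variance explained does not measure predictive relevance.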

Installation

Requirements

  • Python 3.8+
  • R 3.6+ (optional, for JIVE and SLIDE R package integration)

Quick Start

  1. Clone the repository:
git clone https://github.com/lingchm/MMFL.git
cd MMFL
  2. Install Python dependencies:
pip install -r requirements.txt
  3. (Optional) Install R packages for baseline comparisons with JIVE/SJIVE/SLIDE:
# In an R console
install.packages("devtools")
devtools::install_github("irinagain/SLIDE")
devtools::install_github("lockEF/r.jive")

Note: R integration requires rpy2 to be properly configured. See rpy2 documentation for troubleshooting.

Project Structure

MMFL/
├── models/                    # Core model implementations
│   ├── MMFL.py               # Main MMFL algorithm with rank selection
│   ├── MADDi.py              # Multi-modal Attention-based Deep Learning
│   ├── IMLS.py               # Incomplete Multi-modality Latent Space learning
│   └── stagewise.py          # Stagewise deep learning models
├── utils/                     # Utility functions
│   ├── train.py              # Model training and evaluation functions
│   ├── metrics.py            # Evaluation metrics (AUC, accuracy, etc.)
│   ├── prepare_dataset.py    # Data preprocessing and splitting
│   ├── generate_simulation.py # Synthetic data generation
│   ├── visualization.py      # Plotting utilities
│   ├── oversampling.py       # SMOTE oversampling for imbalanced data
│   ├── rank_selection.py     # Rank selection utilities
│   └── compare_auc_delong_xu.py # Statistical comparison methods
├── experiments/               # Experimental notebooks
│   ├── case_study_adni.ipynb # ADNI dataset experiments
│   ├── case_study_headache.ipynb # Headache dataset experiments
│   └── simulation_study.ipynb # Simulation studies
├── preprocessing/             # Data preprocessing scripts
├── data/                      # Data files
│   ├── ADNI_dataset.csv      # ADNI dataset
│   ├── ADNI_SNP_fisher_nature_p0.0005.csv # SNP data
│   └── headache_*.csv        # Headache study data
├── results/                   # Experimental results (JSON format)
└── archive/                   # Legacy code and data

Usage

Basic Usage Example

from models.MMFL import MMFL
from utils.prepare_dataset import train_test_split
import pandas as pd

# Load your multi-modal data
data = pd.read_csv("data/your_dataset.csv")

# Prepare datasets
X_train, X_test, y_train, y_test, mask_train, mask_test, covariates_train, covariates_test = train_test_split(
    data, 
    modalities=['MRI', 'FDG', 'SNP'], 
    train_prop=0.8,
    normalize=True
)

# Initialize and train MMFL model with automatic rank selection
model = MMFL(
    rs="auto",  # Automatic rank selection
    lam=1.0,    # Weight for predictive loss
    gam=0.1,    # L2 regularization on coefficients
    mu=0.01,    # Regularization parameter
    rank_selection_criterion="auc",  # Use AUC for rank selection
    verbose=True
)

# Train the model
model.fit(X_train, y_train, covariates_train, max_iter=100)

# Make predictions
predictions = model.predict(X_test, covariates_test)

# Evaluate performance
from utils.metrics import evaluate_y
metrics = evaluate_y(y_test, predictions, binary=True)
print(f"AUC: {metrics['AUC']:.3f}")
print(f"Accuracy: {metrics['Accuracy']:.3f}")
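The `mask_train`/`mask_test` arrays indicate which modalities are observed for each sample. The exact convention used by `prepare_dataset` is not shown here; a common one, sketched below under that assumption, is a binary samples-by-modalities matrix derived from NaN patterns in the input table:

```python
import numpy as np
import pandas as pd

# Hypothetical wide table: columns grouped per modality, with NaN
# wherever a sample is missing that whole modality.
df = pd.DataFrame({
    "MRI_1": [0.1, np.nan, 0.3],
    "MRI_2": [0.2, np.nan, 0.4],
    "FDG_1": [1.0, 1.1, np.nan],
})

modalities = {"MRI": ["MRI_1", "MRI_2"], "FDG": ["FDG_1"]}

# 1 = modality fully observed for this sample, 0 = entirely missing.
mask = np.column_stack([
    df[cols].notna().all(axis=1).astype(int) for cols in modalities.values()
])
print(mask)
# [[1 1]
#  [0 1]
#  [1 0]]
```

A mask like this lets the model skip the reconstruction and prediction terms for unobserved modalities instead of imputing them.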

Running Experiments

1. Simulation Study

Reproduce simulation experiments to validate the method:

jupyter notebook experiments/simulation_study.ipynb

2. Case Study - ADNI Dataset

Run experiments on the Alzheimer's Disease Neuroimaging Initiative dataset:

jupyter notebook experiments/case_study_adni.ipynb

3. Case Study - Headache Dataset

Run experiments on the headache study dataset:

jupyter notebook experiments/case_study_headache.ipynb

Models

The package includes several models for multi-modal learning:

  • MMFL (Multi-Modal Fission Learning) - Main supervised decomposition model
  • JIVE/SJIVE - Joint and Individual Variation Explained (and its supervised variant)
  • SLIDE - Structural Learning and Integrative Decomposition of multi-view data
  • IMLS (Incomplete Multi-modality Latent Space) - Latent space learning
  • Stagewise Deep Learning - Multi-stage neural network
  • MADDi (Multi-modal Attention-based Deep Learning) - Attention-based fusion

Citation

If you use this code in your research, please cite:

@article{mao2024supervised,
  title={Supervised multi-modal fission learning},
  author={Mao, Lingchao and Su, Yi and Lure, Fleming and Li, Jing and others},
  journal={arXiv preprint arXiv:2409.20559},
  year={2024}
}

Acknowledgments

  • ADNI: Alzheimer's Disease Neuroimaging Initiative for providing the dataset (data available upon request from ADNI)
  • R Package Contributors: Thanks to the maintainers of r.jive and SLIDE R packages
  • Open Source Libraries: scikit-learn, PyTorch, pandas, numpy, and the scientific Python community
