MMFL (Multi-Modal Fission Learning) is a machine learning framework for multi-modal datasets with missing modalities. The framework employs a structural decomposition approach, separating multi-modal data into globally joint, locally joint, and individual components under label supervision. This enables effective learning even when some modalities are missing for certain samples, without requiring data imputation.
## Key Features
- Missing data robustness: Gracefully handles incomplete modalities without requiring imputation
- Structural decomposition: Decomposes data into joint (global/local) and individual components using structured factorization
- Supervised learning: Jointly learns prediction model and latent components
- Automatic rank selection: Built-in rank selection heuristics
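The fission structure above can be illustrated with a small synthetic example. Everything here is illustrative: the modality names, dimensions, ranks, and the 80% observation rate are arbitrary choices, not part of the package's API; MMFL itself estimates the factors from data rather than being given them.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
dims = {"MRI": 50, "FDG": 40, "SNP": 60}      # hypothetical modality dimensions

# Latent factors: one set shared by ALL modalities (globally joint), one shared
# by only a subset (locally joint, here MRI + FDG), plus per-modality factors.
U_global = rng.normal(size=(n, 2))
U_local = rng.normal(size=(n, 2))
X = {}
for name, p in dims.items():
    block = U_global @ rng.normal(size=(2, p))                    # globally joint part
    if name in ("MRI", "FDG"):
        block += U_local @ rng.normal(size=(2, p))                # locally joint part
    block += rng.normal(size=(n, 3)) @ rng.normal(size=(3, p))    # individual part
    X[name] = block + 0.1 * rng.normal(size=(n, p))               # observation noise

# Missing-modality indicator (1 = observed). Mask-based methods like MMFL fit
# only the observed blocks instead of imputing the missing ones.
mask = {name: (rng.random(n) < 0.8).astype(int) for name in dims}
print({name: X[name].shape for name in X})
```

Data generated this way (joint plus individual signal, with a per-sample observation mask) is essentially what the simulation utilities in `utils/generate_simulation.py` produce for the experiments.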
## Requirements

- Python 3.8+
- R 3.6+ (optional, for the JIVE and SLIDE R-package baselines)
## Installation

1. Clone the repository:

   ```shell
   git clone https://github.com/yourusername/MMFL.git
   cd MMFL
   ```

2. Install Python dependencies:

   ```shell
   pip install -r requirements.txt
   ```

3. (Optional) Install the R packages used for baseline comparisons with JIVE/SJIVE/SLIDE:

   ```r
   # In an R console
   install.packages("devtools")
   devtools::install_github("irinagain/SLIDE")
   devtools::install_github("lockEF/r.jive")
   ```

   Note: R integration requires rpy2 to be properly configured. See the rpy2 documentation for troubleshooting.
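To verify the optional R setup, a small sanity check like the following can help. It is not part of the repository; it only assumes rpy2's standard `isinstalled` helper and the package names installed above, and degrades gracefully when rpy2 or R is absent.

```python
def r_integration_status():
    """Report whether rpy2 and the optional R baseline packages are usable."""
    try:
        # Importing rpy2.robjects requires a working R installation.
        from rpy2.robjects.packages import isinstalled
    except Exception:
        return "rpy2/R not available; JIVE and SLIDE baselines are disabled"
    missing = [pkg for pkg in ("r.jive", "SLIDE") if not isinstalled(pkg)]
    if missing:
        return "missing R packages: " + ", ".join(missing)
    return "R integration ready"

print(r_integration_status())
```

Running this before the experiment notebooks makes it clear whether the R baselines will be skipped.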
## Project Structure

```
MMFL/
├── models/                          # Core model implementations
│   ├── MMFL.py                      # Main MMFL algorithm with rank selection
│   ├── MADDi.py                     # Multi-modal Attention-based Deep Learning
│   ├── IMLS.py                      # Incomplete Multi-modality Latent Space learning
│   └── stagewise.py                 # Stagewise deep learning models
├── utils/                           # Utility functions
│   ├── train.py                     # Model training and evaluation functions
│   ├── metrics.py                   # Evaluation metrics (AUC, accuracy, etc.)
│   ├── prepare_dataset.py           # Data preprocessing and splitting
│   ├── generate_simulation.py       # Synthetic data generation
│   ├── visualization.py             # Plotting utilities
│   ├── oversampling.py              # SMOTE oversampling for imbalanced data
│   ├── rank_selection.py            # Rank selection utilities
│   └── compare_auc_delong_xu.py     # Statistical comparison methods
├── experiments/                     # Experimental notebooks
│   ├── case_study_adni.ipynb        # ADNI dataset experiments
│   ├── case_study_headache.ipynb    # Headache dataset experiments
│   └── simulation_study.ipynb       # Simulation studies
├── preprocessing/                   # Data preprocessing scripts
├── data/                            # Data files
│   ├── ADNI_dataset.csv             # ADNI dataset
│   ├── ADNI_SNP_fisher_nature_p0.0005.csv  # SNP data
│   └── headache_*.csv               # Headache study data
├── results/                         # Experimental results (JSON format)
└── archive/                         # Legacy code and data
```
## Quick Start

```python
from models.MMFL import MMFL
from utils.prepare_dataset import train_test_split
from utils.metrics import evaluate_y
import numpy as np
import pandas as pd

# Load your multi-modal data
data = pd.read_csv("data/your_dataset.csv")

# Prepare datasets
X_train, X_test, y_train, y_test, mask_train, mask_test, covariates_train, covariates_test = train_test_split(
    data,
    modalities=['MRI', 'FDG', 'SNP'],
    train_prop=0.8,
    normalize=True
)

# Initialize the MMFL model with automatic rank selection
model = MMFL(
    rs="auto",                         # automatic rank selection
    lam=1.0,                           # weight for the predictive loss
    gam=0.1,                           # L2 regularization on coefficients
    mu=0.01,                           # regularization parameter
    rank_selection_criterion="auc",    # use AUC for rank selection
    verbose=True
)

# Train the model
model.fit(X_train, y_train, covariates_train, max_iter=100)

# Make predictions
predictions = model.predict(X_test, covariates_test)

# Evaluate performance
metrics = evaluate_y(y_test, predictions, binary=True)
print(f"AUC: {metrics['AUC']:.3f}")
print(f"Accuracy: {metrics['Accuracy']:.3f}")
```

## Experiments

Reproduce the simulation experiments that validate the method:

```shell
jupyter notebook experiments/simulation_study.ipynb
```

Run experiments on the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset:

```shell
jupyter notebook experiments/case_study_adni.ipynb
```

Run experiments on the headache study dataset:

```shell
jupyter notebook experiments/case_study_headache.ipynb
```

## Models

The package includes several models for multi-modal learning:
- MMFL (Multi-Modal Fission Learning) - Main supervised decomposition model
- JIVE/SJIVE - Joint and Individual Variation Explained (and its supervised variant)
- SLIDE - Structural Learning and Integrative DEcomposition of multi-view data
- IMLS (Incomplete Multi-modality Latent Space) - Latent space learning
- Stagewise Deep Learning - Multi-stage neural network
- MADDi (Multi-modal Attention-based Deep Learning) - Attention-based fusion
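Models are typically compared by test-set AUC, with significance assessed by DeLong's test (the repository ships an implementation in `utils/compare_auc_delong_xu.py`). As a purely illustrative, self-contained sketch of what such a comparison computes, here is a compact DeLong test for two correlated AUCs, independent of the repository's code:

```python
import numpy as np
from math import erf, sqrt

def _psi(pos, neg):
    # Pairwise comparison kernel: 1 if pos > neg, 0.5 on ties, 0 otherwise.
    return (pos[:, None] > neg[None, :]).astype(float) + 0.5 * (pos[:, None] == neg[None, :])

def delong_test(y, scores_a, scores_b):
    """Two-sided DeLong test for the difference of two correlated AUCs."""
    y = np.asarray(y)
    aucs, v10, v01 = [], [], []
    for s in (np.asarray(scores_a), np.asarray(scores_b)):
        pos, neg = s[y == 1], s[y == 0]
        psi = _psi(pos, neg)                  # shape (n_pos, n_neg)
        aucs.append(psi.mean())               # Mann-Whitney estimate of the AUC
        v10.append(psi.mean(axis=1))          # structural components over positives
        v01.append(psi.mean(axis=0))          # structural components over negatives
    m, n = v10[0].size, v01[0].size
    s10, s01 = np.cov(np.vstack(v10)), np.cov(np.vstack(v01))   # 2x2 covariances
    var = (s10[0, 0] + s10[1, 1] - 2 * s10[0, 1]) / m \
        + (s01[0, 0] + s01[1, 1] - 2 * s01[0, 1]) / n
    z = (aucs[0] - aucs[1]) / sqrt(max(var, 1e-12))
    p = 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))             # two-sided p-value
    return aucs[0], aucs[1], p
```

For example, `delong_test(y_test, model_a_scores, model_b_scores)` returns both AUCs and a p-value for their difference; the same scores from two models share the structural components, which is what distinguishes this test from comparing two independent AUCs.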
## Citation

If you use this code in your research, please cite:

```bibtex
@article{mao2024supervised,
  title={Supervised multi-modal fission learning},
  author={Mao, Lingchao and Su, Yi and Lure, Fleming and Li, Jing and others},
  journal={arXiv preprint arXiv:2409.20559},
  year={2024}
}
```

## Acknowledgments

- ADNI: Alzheimer's Disease Neuroimaging Initiative for providing the dataset (data available upon request from ADNI)
- R Package Contributors: Thanks to the maintainers of the `r.jive` and `SLIDE` R packages
- Open Source Libraries: scikit-learn, PyTorch, pandas, numpy, and the scientific Python community