REGINA: Regularized Encoder with Latent Cycle-GAN for In-vitro Neural Cell Perturbation Approximation.
In this repository we provide all the material needed to reproduce our REGINA research. REGINA is a virtual cell modelling pipeline built on generative AI approaches centered around a Cycle-GAN workflow.
For the easiest possible reproduction we include the Singularity definition file we used to work with the pipeline. To use it, you first need to install Singularity (link).
To install REGINA run:
singularity build --fix-perms REGINA.sif REGINA.def
Data is downloaded via GEARS (link):
from gears import PertData, GEARS
# get data
dataset_name = 'norman'
pert_data = PertData('./data')
# load dataset in paper: norman, adamson, dixit.
pert_data.load(data_name=dataset_name, data_path=None)
Our method requires splitting the data into train, validation, and test sets. You can either use GEARS' custom splitting method (which we used) or use the split dict we obtained and included.
Splitting the data:
# adata is the AnnData object loaded by GEARS (assuming pert_data.adata);
# perturbation_key is the obs column holding the perturbation label (e.g. 'condition' in GEARS datasets);
# custom_split_dict maps 'train'/'val'/'test' to lists of perturbation labels.
adata = pert_data.adata
perturbation_key = 'condition'

train_adata = adata[adata.obs[perturbation_key].isin(custom_split_dict['train'])]
train_adata.write_h5ad(f"data/{dataset_name}/train.h5ad")
val_adata = adata[adata.obs[perturbation_key].isin(custom_split_dict['val'])]
val_adata.write_h5ad(f"data/{dataset_name}/val.h5ad")
test_adata = adata[adata.obs[perturbation_key].isin(custom_split_dict['test'])]
test_adata.write_h5ad(f"data/{dataset_name}/test.h5ad")
To use the REGINA latent classifier, you can add any method that generates class information. The method we used for the dixit, norman, and adamson datasets is:
import scanpy as sc

def add_binary_state(adata):
    # Use gene names as var_names so the marker prefixes below can be matched.
    adata.var_names = adata.var.gene_name
    # Prefixes of canonical stress-response genes.
    stress_prefixes = ('HSP', 'ATF', 'DNAJ', 'ERN', 'EIF2', 'CEBP')
    available_genes = adata.var_names.tolist()
    valid_markers = [g for g in available_genes if g.startswith(stress_prefixes)]
    # Score each cell by the expression of the stress markers.
    sc.tl.score_genes(adata, gene_list=valid_markers, score_name='stress_score')
    # Label the top 30% of cells by stress score as 'Stressed'.
    threshold = adata.obs['stress_score'].quantile(0.70)
    adata.obs['cell_state'] = 'Homeostasis'
    adata.obs.loc[adata.obs['stress_score'] > threshold, 'cell_state'] = 'Stressed'
    print("\nFinal State Distribution:")
    print(adata.obs['cell_state'].value_counts())
    return adata
import pickle
import anndata as ad

train_adata = ad.read_h5ad(f"data/{dataset_name}/train.h5ad")
val_adata = ad.read_h5ad(f"data/{dataset_name}/val.h5ad")
test_adata = ad.read_h5ad(f"data/{dataset_name}/test.h5ad")

train_adata = add_binary_state(train_adata)
val_adata = add_binary_state(val_adata)
test_adata = add_binary_state(test_adata)

train_adata.write_h5ad(f"data/{dataset_name}/train_processed.h5ad")
val_adata.write_h5ad(f"data/{dataset_name}/val_processed.h5ad")
test_adata.write_h5ad(f"data/{dataset_name}/test_processed.h5ad")

# Save the gene-name-to-index mapping used by the models.
gene_to_idx = {gene: i for i, gene in enumerate(train_adata.var_names)}
with open(f"data/{dataset_name}/gene_to_idx.pkl", "wb") as f:
    pickle.dump(gene_to_idx, f)
After preprocessing, you can simply run train.sh from bash, or:
singularity exec --nv REGINA.sif python3 train_models.py
At the end of the file you can modify which dataset you want to train on. Please ignore the rest of the models!
For evaluation you can call eval.sh, or:
singularity exec --nv REGINA.sif python3 evaluate_h5ad.py
Figure 1: Training pipeline of the regularized autoencoder.
Our approach follows the general paradigm of latent diffusion models, in which high-dimensional observations are first mapped into a lower-dimensional latent space via an autoencoder.
Given the high dimensionality of gene expression profiles, processing the full vector with standard linear layers is computationally heavy. To address this, we tokenize the input vector into a sequence of lower-dimensional segments. These segments are then processed by a Transformer encoder, utilizing bidirectional self-attention to capture complex, non-local gene dependencies.
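A minimal PyTorch sketch of this tokenization step is given below; the segment size, embedding width, and layer counts are illustrative placeholders rather than the settings used in REGINA.
import torch
import torch.nn as nn

class TokenizedEncoder(nn.Module):
    """Split a gene expression vector into fixed-size segments and encode them with a Transformer."""
    def __init__(self, n_genes=5000, segment_size=100, d_model=128, n_layers=4, latent_dim=64):
        super().__init__()
        assert n_genes % segment_size == 0
        self.segment_size = segment_size
        self.n_tokens = n_genes // segment_size
        self.segment_proj = nn.Linear(segment_size, d_model)          # one token per segment
        self.pos_emb = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.to_latent = nn.Linear(self.n_tokens * d_model, latent_dim)

    def forward(self, x):                                             # x: (batch, n_genes)
        b = x.shape[0]
        tokens = x.view(b, self.n_tokens, self.segment_size)          # (batch, n_tokens, segment_size)
        tokens = self.segment_proj(tokens) + self.pos_emb             # embed segments + positional encoding
        h = self.transformer(tokens)                                  # bidirectional self-attention
        return self.to_latent(h.reshape(b, -1))                       # latent vector z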
Let
A standard autoencoder does not explicitly enforce preservation of biologically meaningful class information, such as the cellular state. To be able to study perturbation effects in the latent space, we augment the model with a latent classifier
The classifier is trained using a stop-gradient operation on the encoder output, preventing its updates from affecting the encoder:
where
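As a sketch (the classifier architecture and variable names below are illustrative assumptions, not the exact REGINA implementation), the stop-gradient can be realized with detach():
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical latent classifier: a small MLP predicting the cell state from the latent code.
classifier = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 2))

def classifier_loss(z, cell_state_labels):
    """Train the classifier on detached latents so its gradients never reach the encoder."""
    logits = classifier(z.detach())             # stop-gradient on the encoder output
    return F.cross_entropy(logits, cell_state_labels)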
Training a vanilla autoencoder typically results in an unregularized latent space that is unsuitable for downstream perturbation modeling.
To address this, we introduce a center loss (
To additionally regularize the encoder, we include a second classification term without the stop-gradient:
This term encourages the encoder to produce latents that are themselves predictive of the cell state, while the classifier receives gradients from both terms (which only improves its accuracy and does not harm training dynamics).
The full autoencoder objective is then:
where
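A hedged sketch of how these terms might be combined (the loss weights and the center-loss form below are illustrative assumptions; train_models.py contains the actual objective):
import torch
import torch.nn.functional as F

# Illustrative loss weights; the actual values are hyperparameters of the training script.
w_center, w_cls_sg, w_cls = 0.1, 1.0, 1.0

def autoencoder_objective(encoder, decoder, classifier, centers, x, labels):
    z = encoder(x)
    recon = decoder(z)
    loss_recon = F.mse_loss(recon, x)                                  # reconstruction term
    loss_center = ((z - centers[labels]) ** 2).sum(dim=1).mean()       # center loss: pull latents toward their class center
    loss_cls_sg = F.cross_entropy(classifier(z.detach()), labels)      # classifier term with stop-gradient
    loss_cls = F.cross_entropy(classifier(z), labels)                  # classifier term that also regularizes the encoder
    return loss_recon + w_center * loss_center + w_cls_sg * loss_cls_sg + w_cls * loss_cls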
To ensure training stability and a self-consistent encoder and decoder, we applied consistency monitoring. A latent sample
where
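One plausible reading of this check, assuming consistency is measured by decoding a latent sample and re-encoding the reconstruction (the distance metric below is an assumption), is:
import torch
import torch.nn.functional as F

def consistency_gap(encoder, decoder, z):
    """Decode a latent sample, re-encode the reconstruction, and report the latent-space discrepancy."""
    with torch.no_grad():
        z_reencoded = encoder(decoder(z))
    return F.mse_loss(z_reencoded, z).item()    # monitored during training, not back-propagated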
Figure 2: Training pipeline of the second phase of training.
Due to the absence of paired control–perturbation samples, a supervised latent transition model cannot be trained directly. We therefore adopt a latent Cycle-GAN framework that learns bidirectional mappings between control latent distributions
A forward transition block $T^{fwd}_{\Theta_1}$ maps control latents to perturbed latents, while a backward transition block $T^{bwd}_{\Theta_2}$ performs the inverse mapping:
Since the true inverse of a biological perturbation is unknown, the backward transition model is tasked with learning an implicit inversion corresponding to the perturbation that generated the observed perturbed state. Converting a perturbed state back to a control state has no biological meaning; it serves only as a proxy task that improves the quality of our data generation.
To ensure consistency between these transformations, we apply a cycle-consistency loss
Without additional constraints, the transition model may collapse to a trivial identity mapping, i.e.,
where
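A minimal sketch of the cycle-consistency term, assuming the forward and backward transition blocks are ordinary PyTorch modules (the names and the L1 distance are placeholders):
import torch
import torch.nn.functional as F

def cycle_consistency_loss(T_fwd, T_bwd, z_ctrl, z_pert):
    """Round-trip reconstruction in latent space: ctrl -> pert -> ctrl and pert -> ctrl -> pert."""
    z_ctrl_cycled = T_bwd(T_fwd(z_ctrl))    # forward then backward should recover the control latent
    z_pert_cycled = T_fwd(T_bwd(z_pert))    # backward then forward should recover the perturbed latent
    return F.l1_loss(z_ctrl_cycled, z_ctrl) + F.l1_loss(z_pert_cycled, z_pert)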
To handle unseen perturbations during training, we apply latent prompting that explicitly encodes perturbation information. Given a perturbation index
The transition model is conditioned on this prompt as:
The same prompt is used for both forward and backward transformations.
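As a sketch (the embedding size and the concatenation-based conditioning are assumptions made for illustration), the prompt can be a learned embedding looked up by the perturbation index and passed to the transition block together with the latent:
import torch
import torch.nn as nn

class PromptedTransition(nn.Module):
    """Transition block conditioned on a learned perturbation prompt."""
    def __init__(self, latent_dim=64, n_perturbations=100, prompt_dim=32):
        super().__init__()
        self.prompt_emb = nn.Embedding(n_perturbations, prompt_dim)   # one prompt vector per perturbation index
        self.net = nn.Sequential(
            nn.Linear(latent_dim + prompt_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, z, pert_idx):
        prompt = self.prompt_emb(pert_idx)                 # (batch, prompt_dim)
        return self.net(torch.cat([z, prompt], dim=-1))    # conditioned latent transition

# The same prompt embedding is shared by the forward and backward transition blocks.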