-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_run_tier2.py
More file actions
347 lines (299 loc) · 14.9 KB
/
Copy path_run_tier2.py
File metadata and controls
347 lines (299 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""Tier-2 benchmarks: BINN architecture ablations.

Set FEATURE_SET in the configuration section to choose the input modality:
    "rna"  — RNA genes filtered to Reactome (9,660 features)
    "prot" — surface proteins resolved to HGNC symbols, filtered to Reactome (~164 features)
    "mix"  — RNA + protein combined (protein HGNC names averaged with matching RNA genes)

RUN_BENCHMARKS selects which of the seven benchmarks are run:
    1. BINN GCN with Reactome pathway mask
    2. GCN unconstrained (no mask, same features)
    3. Protein-only GCN reference (its graphs are always built; the model runs only when selected)
    4. BINN GAT with Reactome pathway mask
    5. BINN ANN with Reactome pathway mask
    6. GAT unconstrained (no mask)
    7. ANN unconstrained (no mask)
"""
import gc, sys
from pathlib import Path
import anndata as ad
import muon as mu
import numpy as np
import pandas as pd
from scipy.sparse import issparse
from sklearn.decomposition import PCA
sys.path.insert(0, str(Path(__file__).parent))
from binn.data import get_map, prepareAnnData, anndata_to_graph_data, build_reactome_network
from binn.data.protein_synonyms import resolve_protein_name
from binn.learn import Hyperparameters
from binn.train import binn_donor_cv
# ── Configuration ─────────────────────────────────────────────────────────────
# Input modality for the BINN benchmarks: "rna" | "prot" | "mix".
FEATURE_SET = "mix"
# Number of cells to subsample; None keeps every cell present in both modalities.
SAMPLE_SIZE = None
RANDOM_STATE = 42
# Depth of the Reactome hierarchy requested from get_map().
N_LEVELS = 3

# MuData modality keys.
RNA_MOD = "rna"
PROT_MOD = "prot"

# obs column with coarse cell-type labels, and the label groups forming each
# lineage that is analysed separately.
LINEAGE_COL = "manual.coarse.idents"
LINEAGE_MAP = {"nk": ["NK"], "t": ["CD4 T", "CD8 T"], "b": ["B", "PB"]}

# Per-lineage model-name maps, passed as ``names=`` to run_trials.
NAMES = dict.fromkeys(LINEAGE_MAP, "GCN")
NAMES_GAT = dict.fromkeys(LINEAGE_MAP, "GAT")
NAMES_ANN = dict.fromkeys(LINEAGE_MAP, "ANN")

# Benchmark ids to execute (numbered as in the "Run benchmarks" section):
# 1 BINN GCN (mask), 2 GCN (no mask), 3 protein-only GCN, 4 BINN GAT (mask),
# 5 BINN ANN (mask), 6 GAT (no mask), 7 ANN (no mask).
RUN_BENCHMARKS = {1, 5}
N_TRIALS = 3
# Neighbour count (knn=) for anndata_to_graph_data. Kept small to save memory;
# v3 used KNN=10.
KNN = 5

CV_KWARGS = {
    "embed_dim": 32,
    "fusion_hidden": 128,
    "fusion_dropout": 0.3,
    "use_class_weights": True,
    "val_size": 0.2,
    "n_splits": 5,
    "random_state": RANDOM_STATE,
}
# ── Helpers ───────────────────────────────────────────────────────────────────
def to_dense(ad_obj, dtype=np.float32):
    """Return ``ad_obj.X`` as a dense DataFrame of ``dtype`` labelled cells × features.

    The index is ``ad_obj.obs_names`` and the columns are ``ad_obj.var_names``.
    The returned frame never shares memory with ``ad_obj.X``.

    Memory: a sparse matrix is densified exactly once. The cast to ``dtype`` is
    done in place when the densified array already has that dtype, and the
    DataFrame wraps the fresh buffer without a further copy. This matters because
    callers densify per-lineage blocks of ~10k genes.
    """
    X = ad_obj.X
    if issparse(X):
        # toarray() already allocates a private buffer, so a second copy is waste.
        x = X.toarray().astype(dtype, copy=False)
    else:
        # Dense input: always copy so later edits to the frame cannot alias ad_obj.X.
        x = np.array(X, dtype=dtype)
    # ``x`` is owned only by us, so skip the copy pandas may otherwise make
    # (e.g. under Copy-on-Write).
    return pd.DataFrame(x, index=ad_obj.obs_names, columns=ad_obj.var_names, copy=False)
def dedup_cols(df):
    """Merge same-named columns into one column holding their row-wise mean.

    When every column name is already unique, ``df`` itself is returned (no copy).
    Otherwise the result's columns are the unique names in sorted order, because
    ``groupby`` sorts its keys by default.
    """
    if df.columns.has_duplicates:
        # Transpose so duplicated names become index labels, average per label,
        # then transpose back to cells × features.
        return df.T.groupby(level=0).mean().T
    return df
def run_trials(n_trials, **binn_kwargs):
    """Run binn_donor_cv n_trials times with incrementing random seeds; return averaged metrics.

    Each trial calls ``binn_donor_cv`` with CV_KWARGS merged with ``binn_kwargs``.
    Caller values override the CV_KWARGS defaults. The seed is
    ``base_seed + t`` for trial ``t``, where ``base_seed`` is
    ``binn_kwargs["random_state"]`` if supplied, else RANDOM_STATE.

    Parameters
    ----------
    n_trials : int
        Number of repeated CV runs; must be >= 1.
    **binn_kwargs
        Forwarded to ``binn_donor_cv``.

    Returns
    -------
    dict
        ``mean_acc``, ``std_acc``, ``mean_auc`` and ``std_auc`` over the trials,
        computed from each run's ``mean_acc`` / ``mean_auc``. ``std_*`` is the
        population standard deviation (``np.std``, ddof=0).

    Raises
    ------
    ValueError
        If ``n_trials`` < 1 (mean/std over zero trials would be NaN).
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    # Pull the seed out so it can be varied per trial without a duplicate-kwarg clash.
    base_seed = binn_kwargs.pop("random_state", RANDOM_STATE)
    accs, aucs = [], []
    for t in range(n_trials):
        # Caller kwargs win over the shared CV defaults (previously a TypeError).
        cv_kw = {**CV_KWARGS, **binn_kwargs, "random_state": base_seed + t}
        r = binn_donor_cv(**cv_kw)
        accs.append(r["mean_acc"])
        aucs.append(r["mean_auc"])
        print(f" trial {t+1}/{n_trials} acc={r['mean_acc']:.3f} auc={r['mean_auc']:.3f}")
    return {
        "mean_acc": float(np.mean(accs)),
        "std_acc": float(np.std(accs)),
        "mean_auc": float(np.mean(aucs)),
        "std_auc": float(np.std(aucs)),
    }
# Valid values for FEATURE_SET / build_feature_matrix(feature_set=...).
_FEATURE_SETS = ("rna", "prot", "mix")


def _rna_reactome_features(rna_ad, reactome_genes):
    """Dense RNA features whose upper-cased gene symbol is in ``reactome_genes``."""
    names_up = [str(c).upper() for c in rna_ad.var_names]
    keep = [i for i, n in enumerate(names_up) if n in reactome_genes]
    if not keep:
        # Same shape the slice+densify path would give (cells × 0), without slicing.
        return pd.DataFrame(index=rna_ad.obs_names)
    feat = to_dense(rna_ad[:, keep].copy())
    feat.columns = [names_up[i] for i in keep]
    return feat


def _prot_reactome_features(prot_ad, reactome_genes):
    """Dense protein features renamed to resolved HGNC symbols, filtered to ``reactome_genes``."""
    names_hgnc = [resolve_protein_name(n) for n in prot_ad.var_names]
    keep = [i for i, n in enumerate(names_hgnc) if n in reactome_genes]
    if not keep:
        return pd.DataFrame(index=prot_ad.obs_names)
    feat = to_dense(prot_ad[:, keep].copy())
    feat.columns = [names_hgnc[i] for i in keep]
    return feat


def build_feature_matrix(rna_ad, prot_ad, reactome_genes, feature_set):
    """
    Return (merged_df, obs_keep) for the chosen feature_set.
    merged_df : DataFrame — cells × (features + obs columns); feature columns are
                the sorted Reactome-matched symbols, each appearing once.
    obs_keep  : list[str] — column names that are metadata, not features
    feature_set must be one of "rna", "prot", "mix". In every mode, columns that
    map to the same symbol are averaged (dedup_cols); in "mix" this merges a
    protein with its matching RNA gene.
    Raises ValueError for an unknown feature_set or when no feature matches Reactome.
    """
    if feature_set not in _FEATURE_SETS:
        # Previously any unrecognised value silently fell through to "mix".
        raise ValueError(f"Unknown feature_set {feature_set!r}; expected one of {_FEATURE_SETS}.")
    obs_df = rna_ad.obs.copy()
    obs_keep = obs_df.columns.tolist()
    if feature_set == "rna":
        feat_df = _rna_reactome_features(rna_ad, reactome_genes)
    elif feature_set == "prot":
        feat_df = _prot_reactome_features(prot_ad, reactome_genes)
        if feat_df.shape[1] == 0:
            raise ValueError("No protein features matched Reactome after synonym mapping.")
    else:  # "mix"
        feat_df = pd.concat(
            [_rna_reactome_features(rna_ad, reactome_genes),
             _prot_reactome_features(prot_ad, reactome_genes)],
            axis=1,
        )
    # Dedup in every mode: upper-casing RNA names can also create duplicates,
    # which "rna" mode previously passed through as repeated feature columns.
    feat_df = dedup_cols(feat_df)
    keep_feats = sorted(reactome_genes & set(feat_df.columns))
    if not keep_feats:
        raise ValueError(f"No features matched Reactome for feature_set={feature_set!r}.")
    merged = pd.concat([feat_df[keep_feats], obs_df[obs_keep]], axis=1)
    print(f" Feature matrix: {len(keep_feats)} features ({feature_set})")
    return merged, obs_keep
# ── 1. Load + sample ──────────────────────────────────────────────────────────
_n_str = "all" if SAMPLE_SIZE is None else f"{SAMPLE_SIZE:,}"
print(f"[1/4] Loading + sampling ({_n_str} cells)...")
mdata = mu.read("citeseq_logcounts.h5mu")
# Only cells present in both modalities can be paired; sort for a stable order.
common_cells = sorted(
    set(mdata[RNA_MOD].obs_names).intersection(mdata[PROT_MOD].obs_names)
)
sampled_cells = common_cells
if SAMPLE_SIZE is not None:
    rng = np.random.default_rng(RANDOM_STATE)
    sampled_cells = list(rng.choice(common_cells, size=SAMPLE_SIZE, replace=False))
# Slice both modalities with the same cell list so their rows line up.
rna_ad, prot_ad = (mdata[mod][sampled_cells].copy() for mod in (RNA_MOD, PROT_MOD))
# ── 2. Build Reactome map ─────────────────────────────────────────────────────
print("[2/4] Building Reactome map...")
reactome_net = build_reactome_network()
map_raw = get_map(reactome_net, n_levels=N_LEVELS).copy()
del reactome_net  # only needed to derive the map
# Upper-case the gene level so it matches the upper-cased feature names.
map_raw["layer0"] = map_raw["layer0"].astype(str).str.upper()
reactome_genes = set(map_raw["layer0"].tolist())
# ── 3+4. Per-lineage feature matrices + graphs ────────────────────────────────
# Process one lineage at a time from the *sparse* AnnData slices so we never
# materialise the full N_cells × N_features dense matrix. Peak memory is
# bounded by the largest single lineage (T cells ≈ 50% of cells) rather than
# all cells at once, which OOMs above ~80k cells at 9,661 features.
print(f"[3/4] Building feature matrices + graphs (FEATURE_SET='{FEATURE_SET}')...")
coarse = rna_ad.obs[LINEAGE_COL].astype(str)
pca_src = mdata[RNA_MOD].obsm["PCA"]
pca_df = pd.DataFrame(pca_src, index=mdata[RNA_MOD].obs_names.astype(str))
# The full MuData is no longer needed: rna_ad / prot_ad are independent copies
# and pca_df holds the embedding. Release it before the dense per-lineage work.
del mdata, pca_src
gc.collect()
graph_data = {}
map_dfs = {}
obs_adatas = {}
pathway_map = None
for name, labels in LINEAGE_MAP.items():
    mask = coarse.isin(labels)
    # Barcodes of this lineage, in the row order of rna_lin / merged_lin.
    lin_ids = coarse.index[mask.values]
    rna_lin = rna_ad[mask].copy()
    prot_lin = prot_ad[mask].copy()
    merged_lin, obs_keep = build_feature_matrix(rna_lin, prot_lin, reactome_genes, FEATURE_SET)
    del rna_lin, prot_lin
    gc.collect()
    adata_lin, pm = prepareAnnData(data=merged_lin, obs_vars=obs_keep, map=map_raw)
    del merged_lin
    gc.collect()
    if pathway_map is None:
        pathway_map = pm  # identical for every lineage — compute once
    # NOTE(review): assumes prepareAnnData relabels rows with their 0-based
    # position in `data` (merged_lin), i.e. positions are *within this lineage*.
    # The old lookup `sampled_cells[p]` treated them as positions in the whole
    # sample, pairing each cell with another cell's PCA row.
    positions = [int(i) for i in adata_lin.obs_names]
    if positions and max(positions) >= len(lin_ids):
        raise ValueError(
            f"prepareAnnData returned row position {max(positions)} for lineage "
            f"{name!r}, which has only {len(lin_ids)} cells."
        )
    orig_ids = lin_ids[positions]
    adata_lin.obsm["PCA"] = pca_df.loc[orig_ids].to_numpy()
    # Use the configured KNN (was hard-coded to 5, ignoring the setting).
    graph_data[name] = anndata_to_graph_data(adata_lin, group="Status", knn=KNN,
                                             knn_fn=lambda a: a.obsm["PCA"])
    map_dfs[name] = pm.copy()
    obs_adatas[name] = adata_lin
    print(f" {name}: {adata_lin.shape[0]:,} cells")
ref = next(iter(graph_data.values()))
# ── Protein-only reference graph (always built for benchmark 3) ───────────────
# The protein matrix is small (original note: ~224 features), so dense
# conversion over all cells is fine.
print(" Building protein-only reference graph...")
prot_names_hgnc = [resolve_protein_name(n) for n in prot_ad.var_names]
prot_df_full = to_dense(prot_ad)
del prot_ad
gc.collect()
prot_df_full.columns = prot_names_hgnc
prot_df_full = dedup_cols(prot_df_full)
reactome_prot_cols = sorted(c for c in prot_df_full.columns if c in reactome_genes)
# PCA below uses n_components = min(30, n_features - 1), which must be >= 1.
# Fail with a clear message instead of an opaque sklearn parameter error.
if len(reactome_prot_cols) < 2:
    raise ValueError(
        f"Protein-only reference needs >= 2 Reactome-matched proteins, "
        f"found {len(reactome_prot_cols)}."
    )
prot_df_reactome = prot_df_full[reactome_prot_cols]
del prot_df_full
gc.collect()
print(f" Protein features matched to Reactome: {len(reactome_prot_cols)}")
n_prot_pcs = min(30, len(reactome_prot_cols) - 1)  # loop-invariant
prot_adatas = {}
for name, labels in LINEAGE_MAP.items():
    mask = coarse.isin(labels)
    # mask is aligned with rna_ad, whose rows follow sampled_cells order.
    orig_ids = [sampled_cells[p] for p in mask.values.nonzero()[0]]
    sub_prot = prot_df_reactome.loc[orig_ids]
    # NOTE(review): rows are paired positionally with obs_adatas[name].obs;
    # assumes prepareAnnData kept every lineage cell in its original order.
    sub_adata = ad.AnnData(X=sub_prot.values.astype(np.float32),
                           obs=obs_adatas[name].obs.copy())
    sub_adata.var_names = sub_prot.columns.tolist()
    pca = PCA(n_components=n_prot_pcs, random_state=RANDOM_STATE)
    sub_adata.obsm["ProtPCA"] = pca.fit_transform(sub_prot.values)
    prot_adatas[name] = sub_adata
del rna_ad
gc.collect()
# Protein pathway map — filter main map to Reactome-matched protein genes.
# pathway_map from prepareAnnData uses numeric column names ("0", "1", ...)
gene_col = pathway_map.columns[0]  # first col is always the gene/feature level
map_prot = pathway_map[pathway_map[gene_col].isin(reactome_prot_cols)].copy()
if map_prot.empty:
    # Fallback: flat synthetic hierarchy with the SAME column labels and depth
    # as pathway_map. It was previously hard-coded as "layer0".."layer3", which
    # matched neither the naming above nor N_LEVELS != 3.
    # Group counts 50 → 10 → 2 → 1 … each divide the previous level, so every
    # group nests in exactly one parent. Labels carry the level index so a
    # group name is never reused across levels.
    # NOTE(review): assumes every column after gene_col is a hierarchy level.
    n_p = len(reactome_prot_cols)
    fallback = {gene_col: reactome_prot_cols}
    for lvl, col in enumerate(pathway_map.columns[1:], start=1):
        n_groups = max(1, 50 // 5 ** (lvl - 1))
        fallback[col] = [f"L{lvl}_G{i % n_groups}" for i in range(n_p)]
    map_prot = pd.DataFrame(fallback)
graph_data_prot = {
    name: anndata_to_graph_data(a, group="Status", knn=KNN,
                                knn_fn=lambda x: x.obsm["ProtPCA"])
    for name, a in prot_adatas.items()
}
map_dfs_prot = {mod: map_prot.copy() for mod in NAMES}
obs_adatas_prot = {mod: prot_adatas[mod] for mod in NAMES}
ref_prot = next(iter(graph_data_prot.values()))
# ── Hyperparameters ───────────────────────────────────────────────────────────
# The main and protein-only configs share every optimisation setting; only the
# input feature width (taken from a representative graph) differs.
_shared_hparams = dict(
    num_classes=2,
    lr=1e-3,
    w_decay=1e-5,
    epochs=400,
    patience=50,
    min_delta=1e-4,
)
base_config, prot_config = (
    Hyperparameters(num_node_features=graph.num_node_features, **_shared_hparams)
    for graph in (ref, ref_prot)
)
# ── Run benchmarks ────────────────────────────────────────────────────────────
# Inputs for the RNA/mix feature graphs and for the protein-only reference.
_MAIN_INPUTS = dict(graph_datas=graph_data, map_dfs=map_dfs,
                    obs_adatas=obs_adatas, config=base_config)
_PROT_INPUTS = dict(graph_datas=graph_data_prot, map_dfs=map_dfs_prot,
                    obs_adatas=obs_adatas_prot, config=prot_config)
# (benchmark id, header title, per-lineage model names, inputs, use mask, tag),
# listed in execution order.
_BENCHMARKS = (
    (1, f"BINN GCN (Reactome mask, {FEATURE_SET})", NAMES, _MAIN_INPUTS, True,
     f"BINN GCN (mask, {FEATURE_SET})"),
    (2, f"GCN unconstrained (no mask, {FEATURE_SET})", NAMES, _MAIN_INPUTS, False,
     f"GCN (no mask, {FEATURE_SET})"),
    (3, "Protein-only GCN", NAMES, _PROT_INPUTS, True,
     "GCN protein-only (Reactome mask)"),
    (4, f"BINN GAT (Reactome mask, {FEATURE_SET})", NAMES_GAT, _MAIN_INPUTS, True,
     f"BINN GAT (mask, {FEATURE_SET})"),
    (5, f"BINN ANN (Reactome mask, {FEATURE_SET})", NAMES_ANN, _MAIN_INPUTS, True,
     f"BINN ANN (mask, {FEATURE_SET})"),
    (6, f"GAT unconstrained (no mask, {FEATURE_SET})", NAMES_GAT, _MAIN_INPUTS, False,
     f"GAT (no mask, {FEATURE_SET})"),
    (7, f"ANN unconstrained (no mask, {FEATURE_SET})", NAMES_ANN, _MAIN_INPUTS, False,
     f"ANN (no mask, {FEATURE_SET})"),
)
_results = {}
for _bench_id, _title, _names, _inputs, _use_mask, _tag in _BENCHMARKS:
    if _bench_id not in RUN_BENCHMARKS:
        continue
    print(f"\n=== Benchmark {_bench_id}: {_title} — {N_TRIALS} trials ===")
    _results[_bench_id] = run_trials(
        N_TRIALS,
        names=_names, **_inputs,
        label_col="Status", donor_col="Donor",
        use_pathway_mask=_use_mask, tag=_tag,
    )
# Benchmarks that were not selected stay None (skipped in the summary).
r_masked, r_free, r_prot, r_gat, r_ann, r_gat_free, r_ann_free = (
    _results.get(i) for i in range(1, 8)
)
# ── Summary ───────────────────────────────────────────────────────────────────
# One row per benchmark that ran: mean ± std (over N_TRIALS runs) of each run's
# mean accuracy / AUC as returned by run_trials.
W = 72
LABEL_W = 46  # width of the model-label column
CELL_W = 11   # width of one "0.000±0.000" metric cell (header used 10 → misaligned)
n_folds = CV_KWARGS["n_splits"]  # report the fold count actually used (was hard-coded "5")
print("\n" + "=" * W)
print(f"TIER-2 SUMMARY ({_n_str} cells, {n_folds}-fold donor CV, feature_set='{FEATURE_SET}')")
print("=" * W)
print(f" {'Model':<{LABEL_W}} {'Acc':>{CELL_W}} {'AUC':>{CELL_W}}")
print(f" {'-' * LABEL_W} {'-' * CELL_W} {'-' * CELL_W}")
for label, r in [
    (f"BINN GCN (Reactome mask, {FEATURE_SET})", r_masked),
    (f"GCN unconstrained (no mask, {FEATURE_SET})", r_free),
    ("GCN protein-only (Reactome proteins)", r_prot),
    (f"BINN GAT (Reactome mask, {FEATURE_SET})", r_gat),
    (f"GAT unconstrained (no mask, {FEATURE_SET})", r_gat_free),
    (f"BINN ANN (Reactome mask, {FEATURE_SET})", r_ann),
    (f"ANN unconstrained (no mask, {FEATURE_SET})", r_ann_free),
]:
    if r is None:
        continue  # benchmark not selected in RUN_BENCHMARKS
    print(
        f" {label:<{LABEL_W}} "
        f"{r['mean_acc']:.3f}±{r['std_acc']:.3f} "
        f"{r['mean_auc']:.3f}±{r['std_auc']:.3f}"
    )
print("=" * W)