From dc0ec1462469d5769f83ab7517e33c2ad451192c Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Tue, 2 Jun 2026 11:16:00 +0200 Subject: [PATCH] PseudobulkSpace: align extra obs columns by grouping keys When `groups_col` is provided, `ps_adata.obs.index` is a joined string like "patient_cluster", so reindexing the per-group lookup against that index returned NaN for every extra column. The all-NaN `obs` then made formulaic drop every row, producing an empty-index design matrix that `DeseqDataSet` rejected (scverse/pertpy#1003). Reindex by the grouping columns themselves instead. Co-Authored-By: Claude Opus 4.7 (1M context) --- pertpy/tools/_perturbation_space/_simple.py | 18 ++++++---- .../test_simple_perturbation_space.py | 33 +++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 7c22854b..a809dd6c 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -147,14 +147,18 @@ def compute( ps_adata.X = ps_adata.layers[mode] missing_cols = [col for col in original_obs.columns if col not in ps_adata.obs.columns] - new_cols_data = {} - for col in missing_cols: - grouped_values = original_obs.groupby(grouping_cols, observed=False)[col].first() - new_cols_data[col] = grouped_values.reindex(ps_adata.obs.index).values - - if new_cols_data: - ps_adata.obs = pd.concat([ps_adata.obs, pd.DataFrame(new_cols_data, index=ps_adata.obs.index)], axis=1) + if missing_cols: + grouped = original_obs.groupby(grouping_cols, observed=False)[missing_cols].first() + # `ps_adata.obs` is indexed by joined group keys (e.g. "patient_cluster"), + # so reindex by the grouping columns themselves rather than the joined index. + if len(grouping_cols) == 1: + lookup = pd.Index(ps_adata.obs[grouping_cols[0]]) + else: + lookup = pd.MultiIndex.from_frame(ps_adata.obs[grouping_cols]) + aligned = grouped.reindex(lookup) + aligned.index = ps_adata.obs.index + ps_adata.obs = pd.concat([ps_adata.obs, aligned], axis=1) ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category") diff --git a/tests/tools/_perturbation_space/test_simple_perturbation_space.py b/tests/tools/_perturbation_space/test_simple_perturbation_space.py index 8a1a1627..aa998abb 100644 --- a/tests/tools/_perturbation_space/test_simple_perturbation_space.py +++ b/tests/tools/_perturbation_space/test_simple_perturbation_space.py @@ -93,6 +93,39 @@ def test_pseudobulk_response(adata_simple): ) +def test_pseudobulk_preserves_extra_obs_with_groups_col(rng): + """Regression test for https://github.com/scverse/pertpy/issues/1003. + + When `groups_col` is provided, the pseudobulk output's obs index is joined (e.g. "P0_C0"), + so reindexing extra obs columns must use the grouping columns as keys, not the joined index. + Otherwise every extra column ends up all-NaN and downstream tools (e.g. PyDESeq2) fail. + """ + patients = [f"P{i}" for i in range(4)] + clusters = [f"C{i}" for i in range(2)] + efficacy_per_patient = {"P0": "SD", "P1": "PR", "P2": "PD", "P3": "SD"} + n_cells = 200 + patient_choice = rng.choice(patients, size=n_cells) + cluster_choice = rng.choice(clusters, size=n_cells) + obs = pd.DataFrame( + { + "Patient": pd.Categorical(patient_choice, categories=patients), + "Cluster": pd.Categorical(cluster_choice, categories=clusters), + "Efficacy": pd.Categorical( + [efficacy_per_patient[p] for p in patient_choice], categories=["SD", "PR", "PD"] + ), + } + ) + X = rng.poisson(5, size=(n_cells, 10)).astype(float) + adata = AnnData(X=X, obs=obs) + + ps = pt.tl.PseudobulkSpace() + pdata = ps.compute(adata, target_col="Patient", groups_col="Cluster", mode="sum") + + assert pdata.obs["Efficacy"].isna().sum() == 0 + for row in pdata.obs.itertuples(): + assert row.Efficacy == efficacy_per_patient[row.Patient] + + def test_centroid_umap_response(): X = np.zeros((10, 5))