Skip to content

Commit 56ade41

Browse files
committed
Refactor tests and add new annotation methods (need to add tests).
1 parent 5296ae6 commit 56ade41

20 files changed

Lines changed: 171 additions & 20 deletions

CONTRIBUTING.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Ready to contribute? Here's how to set up `stlearn` for local development.
7070
$ conda create -n stlearn-dev python=3.12 --y
7171
$ conda activate stlearn-dev
7272
$ cd stlearn/
73-
$ pip install -e .[dev,test]
73+
$ pip install -e ".[dev,test]"
7474

7575
You can also use conda to install these dependencies (after creating the environment):
7676
$ conda install -c conda-forge leidenalg python-igraph
@@ -80,7 +80,7 @@ Ready to contribute? Here's how to set up `stlearn` for local development.
8080
$ python -m venv stlearn-env
8181
$ source stlearn-env/bin/activate # On Windows: stlearn-env\Scripts\activate
8282
$ cd stlearn/
83-
$ pip install -e .[dev,test]
83+
$ pip install -e ".[dev,test]"
8484

8585
4. Create a branch for local development::
8686

pyproject.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ readme = { file = "README.md", content-type = "text/markdown" }
1313
license = { text = "BSD license" }
1414
requires-python = ">=3.12"
1515
dependencies = [
16+
"anndata>=0.10.0,<0.12",
1617
"bokeh>=3.7.0,<4.0",
1718
"click>=8.2.0,<9.0",
1819
"igraph>=1.0.0",
1920
"leidenalg>=0.11.0",
2021
"numba>=0.58.1",
21-
"numpy>=1.26.0,<2.0",
22+
"numpy>=1.26.0",
2223
"pillow>=11.0.0,<12.0",
2324
"scanpy>=1.11.0,<2.0",
2425
"scikit-image>=0.22.0",
@@ -28,6 +29,10 @@ dependencies = [
2829
"imageio>=2.37.0,<3.0",
2930
"scipy>=1.11.0,<2.0",
3031
"scikit-learn>=1.7.0,<2.0",
32+
"spatialdata>=0.2.5,<0.3",
33+
"spatialdata-io>=0.1.5,<0.2",
34+
"geopandas>=1.0.0,<2.0",
35+
"shapely>=2.0.0,<3.0",
3136
]
3237
keywords = ["stlearn"]
3338
classifiers = [
@@ -86,9 +91,6 @@ include = ["stlearn", "stlearn.*"]
8691
[tool.setuptools.package-data]
8792
"*" = ["*"]
8893

89-
[tool.setuptools.dynamic]
90-
dependencies = { file = ["requirements.txt"] }
91-
9294
[tool.ruff]
9395
target-version = "py311"
9496
line-length = 88

stlearn/_compat.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import spatialdata as sd
2+
from anndata import AnnData
3+
4+
5+
def get_adata(data, table_key="table"):
6+
"""Extract AnnData from either SpatialData or AnnData input."""
7+
if isinstance(data, sd.SpatialData):
8+
return data.tables[table_key]
9+
elif isinstance(data, AnnData):
10+
return data
11+
else:
12+
raise TypeError(f"Expected SpatialData or AnnData, got {type(data)}")
13+
14+
15+
def is_spatial_data(data):
16+
return isinstance(data, sd.SpatialData)

stlearn/add.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from .adds.add_deconvolution import add_deconvolution
2-
from .adds.add_image import image
3-
from .adds.add_labels import labels
2+
from .adds.image import image
3+
from .adds.labels import labels
44
from .adds.add_loupe_clusters import add_loupe_clusters
5-
from .adds.add_lr import lr
5+
from .adds.lr import lr
66
from .adds.add_mask import add_mask, apply_mask
7-
from .adds.add_positions import positions
7+
from .adds.positions import positions
88
from .adds.annotation import annotation
99
from .adds.parsing import parsing
1010

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pathlib import Path
2+
3+
import geopandas as gpd
4+
5+
from stlearn._compat import get_adata, is_spatial_data
6+
7+
8+
def polygon_annotations(
9+
data,
10+
annotations,
11+
label_column="label",
12+
obs_key="region",
13+
spatial_key="spatial",
14+
table_key="table",
15+
copy=False,
16+
):
17+
"""
18+
Annotate cells/spots by spatial overlap with polygon regions.
19+
20+
Parameters
21+
----------
22+
data
23+
SpatialData or AnnData object with spatial coordinates.
24+
annotations
25+
GeoDataFrame or path to GeoJSON/shapefile.
26+
...
27+
"""
28+
adata = get_adata(data, table_key)
29+
if copy:
30+
adata = adata.copy()
31+
32+
if isinstance(annotations, (str, Path)):
33+
annotations = gpd.read_file(annotations)
34+
35+
coords = adata.obsm[spatial_key]
36+
points = gpd.GeoDataFrame(
37+
index=adata.obs_names,
38+
geometry=gpd.points_from_xy(coords[:, 0], coords[:, 1]),
39+
)
40+
41+
joined = gpd.sjoin(points, annotations, how="left", predicate="within")
42+
joined = joined[~joined.index.duplicated(keep="first")]
43+
adata.obs[obs_key] = (
44+
joined.reindex(adata.obs_names)[label_column].astype("category").values
45+
)
46+
47+
# If SpatialData, also store the polygons as a shapes element
48+
if is_spatial_data(data):
49+
import spatialdata.models as models
50+
51+
parsed = models.ShapesModel.parse(annotations)
52+
data.shapes[obs_key] = parsed
53+
data.tables[table_key] = adata
54+
55+
n_annotated = adata.obs[obs_key].notna().sum()
56+
print(
57+
f"Added polygon annotations to adata.obs['{obs_key}']: "
58+
f"{n_annotated}/{adata.n_obs} cells/spots annotated"
59+
)
60+
61+
if is_spatial_data(data):
62+
return data
63+
return adata if copy else None

stlearn/adds/row_annotations.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from pathlib import Path
2+
3+
import pandas as pd
4+
from anndata import AnnData
5+
6+
7+
def row_annotations(
8+
adata: AnnData,
9+
annotations: pd.DataFrame | str | Path,
10+
join_column: str | None = None,
11+
columns: list[str] | None = None,
12+
copy: bool = False,
13+
) -> AnnData | None:
14+
"""\
15+
Add annotations to adata.obs by joining on cell/spot identifiers.
16+
17+
Merges a DataFrame (or CSV file) into adata.obs based on a
18+
shared index or column. Useful for adding metadata such as
19+
manual labels, clinical annotations, or external classifications.
20+
21+
Parameters
22+
----------
23+
adata
24+
Annotated data matrix.
25+
annotations
26+
DataFrame or path to a CSV/TSV file containing annotations.
27+
join_column
28+
Column in annotations to join on. If None, uses the
29+
DataFrame index. The join is always against adata.obs_names.
30+
columns
31+
Subset of columns to add. If None, adds all columns
32+
(excluding join_column).
33+
copy
34+
Return a copy instead of writing to adata.
35+
36+
Returns
37+
-------
38+
Depending on `copy`, returns or updates `adata` with new
39+
columns added to `adata.obs`.
40+
"""
41+
adata = adata.copy() if copy else adata
42+
43+
if isinstance(annotations, (str, Path)):
44+
path = Path(annotations)
45+
sep = "\t" if path.suffix in (".tsv", ".txt") else ","
46+
annotations = pd.read_csv(path, sep=sep)
47+
48+
if join_column is not None:
49+
if join_column not in annotations.columns:
50+
raise ValueError(
51+
f"Column '{join_column}' not found. "
52+
f"Available: {list(annotations.columns)}"
53+
)
54+
annotations = annotations.set_index(join_column)
55+
56+
if columns is not None:
57+
missing = [c for c in columns if c not in annotations.columns]
58+
if missing:
59+
raise ValueError(f"Columns not found: {missing}")
60+
annotations = annotations[columns]
61+
62+
merged = annotations.reindex(adata.obs_names)
63+
64+
n_matched = merged.notna().any(axis=1).sum()
65+
added_cols = list(merged.columns)
66+
67+
for col in added_cols:
68+
adata.obs[col] = merged[col].values
69+
70+
print(
71+
f"Added {len(added_cols)} column(s) to adata.obs: {added_cols}. "
72+
f"{n_matched}/{adata.n_obs} cells/spots matched."
73+
)
74+
75+
return adata if copy else None

0 commit comments

Comments
 (0)