Skip to content
Merged
30 changes: 26 additions & 4 deletions src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,19 @@ def _ds_shade_categorical(
return _apply_user_alpha(shaded, alpha)


def _color_vector_is_uniform(color_vector: Any) -> bool:
"""True if every entry of the per-point colour vector is identical (so per-point colouring collapses).

Shared by the datashader categorical collapse and the matplotlib scalar-colour fast path.
"""
if color_vector is None or len(color_vector) == 0:
return False
arr = np.asarray(color_vector)
if arr.dtype.kind in "US": # fixed-width strings (the resolved-hex case): cheap vectorised compare
return bool((arr == arr[0]).all())
return pd.Series(arr).nunique(dropna=False) == 1 # object/other: hash-based, no sort


def _shade_datashader_aggregate(
cvs: ds.Canvas,
frame: Any,
Expand Down Expand Up @@ -365,6 +378,14 @@ def _shade_datashader_aggregate(
element-specific prep (geometry transform / point parse), outline rendering, ``_render_ds_image``
and ``_build_ds_colorbar``.
"""
# Single-colour collapse: when every point resolves to the same colour (e.g. past scanpy's
# 102-colour palette all categories are uniform grey), the per-category ds.by aggregate + composite
# is wasted and byte-identical to a plain single-colour count render. Detect it from the colour
# vector and route to the cheap count path — ~14x faster on high-cardinality categoricals (Xenium
# points coloured by gene). _ds_shade_categorical then colours the count by color_vector[0].
if color_by_categorical and _color_vector_is_uniform(color_vector):
col_for_color, color_by_categorical = None, False

agg, reduction_bounds, nan_agg = _ds_aggregate(
cvs, frame, col_for_color, color_by_categorical, ds_reduction, default_reduction, kind
)
Expand All @@ -376,15 +397,16 @@ def _shade_datashader_aggregate(

if (
strip_alpha_hex
and col_for_color is not None # no-color/collapse: _ds_shade_categorical strips color_vector[0] itself
and color_vector is not None
and len(color_vector) > 0
and isinstance(color_vector[0], str)
and color_vector[0].startswith("#")
):
# color_vector usually holds only a few distinct hex strings (one per category), so strip
# alpha on the unique values and map back rather than parsing once per point.
unique_hex, inverse = np.unique(color_vector, return_inverse=True)
color_vector = np.asarray([_hex_no_alpha(c) for c in unique_hex])[inverse]
# Strip alpha on the unique colours and map back rather than parsing once per point; pd.factorize
# dedups in O(n) (hash, no sort) where np.unique would sort millions of strings.
codes, uniques = pd.factorize(np.asarray(color_vector))
color_vector = np.asarray([_hex_no_alpha(c) for c in uniques])[codes]

# density without a color column collapses to a sequential count gradient; everything else with no
# explicit continuous value (categorical or no color) goes through the categorical shader.
Expand Down
78 changes: 19 additions & 59 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
import matplotlib.ticker
import numpy as np
import pandas as pd
import scanpy as sc
import spatialdata as sd
import xarray as xr
from anndata import AnnData
from matplotlib import patheffects
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Colormap, ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from spatialdata import get_extent, get_values
from spatialdata._core.query.relational_query import match_table_to_element
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
from spatialdata.transformations import set_transformation
from spatialdata.transformations.transformations import Identity
Expand All @@ -38,7 +35,6 @@
_get_colors_for_categorical_obs,
_get_linear_colormap,
_map_color_seg,
_maybe_set_colors,
_prepare_cmap_norm,
_resolve_continuous_norm,
resolve_color,
Expand All @@ -48,6 +44,7 @@
_ax_show_and_transform,
_build_ds_colorbar,
_circle_buffer_quad_segs,
_color_vector_is_uniform,
_datashader_canvas_from_dataframe,
_get_extent_and_range_for_datashader_canvas,
_hex_no_alpha,
Expand Down Expand Up @@ -333,7 +330,6 @@ def _add_legend_and_colorbar(
ax: matplotlib.axes.SubplotBase,
cax: ScalarMappable | None,
fig_params: FigParams,
adata: AnnData | None,
col_for_color: str | None,
color_spec: ColorSpec,
palette: ListedColormap | list[str] | None,
Expand Down Expand Up @@ -387,7 +383,6 @@ def _add_legend_and_colorbar(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
value_to_plot=col_for_color,
color_source_vector=color_source_vector,
color_vector=color_vector,
Expand Down Expand Up @@ -755,7 +750,6 @@ def _draw_centroids(xy: np.ndarray, radius: float | None = None) -> None:
color_spec=color_spec,
norm=norm,
na_color=render_params.cmap_params.na_color,
adata=table,
col_for_color=col_for_color,
palette=palette,
fig_params=fig_params,
Expand Down Expand Up @@ -1025,7 +1019,6 @@ def _draw_centroids(xy: np.ndarray, radius: float | None = None) -> None:
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
col_for_color=col_for_color,
color_spec=color_spec,
palette=palette,
Expand Down Expand Up @@ -1059,18 +1052,26 @@ def _scatter_points(
Shared scatter primitive for points and the centroid "fast mode" of shapes/labels;
``color_vector`` is per-point hex strings or numeric values mapped through ``cmap``/``norm``.
"""
# When every marker is the same resolved hex colour (no-color / single colour / collapsed grey), pass a
# scalar ``color=`` instead of a per-point ``c=`` array: matplotlib then skips its per-point colour
# machinery — the dominant cost at scale (10M points: ~9s -> ~3.7s) — for a visually identical result.
# Numeric vectors keep the ``c=``/``cmap``/``norm`` path (they need the colormap).
cv = np.asarray(color_vector)
color_kwargs: dict[str, Any]
if cv.ndim == 1 and cv.dtype.kind in "US" and _color_vector_is_uniform(cv):
color_kwargs = {"color": str(cv[0])}
else:
color_kwargs = {"c": color_vector, "cmap": cmap, "norm": norm}
return ax.scatter(
x,
y,
s=size,
c=color_vector,
rasterized=sc_settings._vector_friendly,
cmap=cmap,
norm=norm,
alpha=alpha,
transform=trans_data,
zorder=zorder,
plotnonfinite=True, # nan points should be rendered as well
**color_kwargs,
)


Expand Down Expand Up @@ -1110,7 +1111,6 @@ def _render_centroids_as_points(
color_spec: ColorSpec,
norm: Normalize | None,
na_color: Any,
adata: AnnData | None,
col_for_color: str | None,
palette: Any,
fig_params: FigParams,
Expand Down Expand Up @@ -1170,7 +1170,6 @@ def _render_centroids_as_points(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
col_for_color=col_for_color,
color_spec=color_spec,
palette=palette,
Expand Down Expand Up @@ -1382,50 +1381,15 @@ def _render_points(
else points_pd_with_color
)

# we construct an anndata to hack the plotting functions
if table_name is None:
adata = AnnData(
X=points[["x", "y"]].values,
obs=points[coords],
dtype=points[["x", "y"]].values.dtype,
)
else:
matched_table = match_table_to_element(sdata=sdata, element_name=element, table_name=table_name)
adata_obs = matched_table.obs.copy()
# if the points are colored by values in X (or a different layer), add the values to obs
if col_for_color in matched_table.var_names:
if table_layer is None:
adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten()
else:
adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten()
adata = AnnData(
X=points[["x", "y"]].values,
obs=adata_obs,
dtype=points[["x", "y"]].values.dtype,
uns=matched_table.uns,
)
sdata_filt[table_name] = adata

# we can modify the sdata because of dealing with a copy
# Color (from a points column, table obs, or table var/X) is already materialized on `points` via
# the get_values merge above; coordinates and the legend come straight from `points`/`color_spec`,
# so no AnnData round-trip is needed. resolve_color reads the original table (kept in sdata_filt)
# for any user-defined uns palette.

# Convert back to dask dataframe to modify sdata
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color)

if col_for_color is not None:
assert isinstance(col_for_color, str)
cols = sc.get.obs_df(adata, [col_for_color])
# maybe set color based on type
if isinstance(cols[col_for_color].dtype, pd.CategoricalDtype):
uns_color_key = f"{col_for_color}_colors"
if uns_color_key in adata.uns:
_maybe_set_colors(
source=adata,
target=adata,
key=col_for_color,
palette=palette,
)

# when user specified a single color, we emulate the form of `na_color` and use it
default_color = (
render_params.color if col_for_color is None and color is not None else render_params.cmap_params.na_color
Expand Down Expand Up @@ -1478,9 +1442,8 @@ def _render_points(
color_spec = color_spec.filter(keep)
if int(keep.sum()) == 0:
return
# filter the materialized points, adata, and re-register in sdata_filt
# filter the materialized points and re-register in sdata_filt
points = points[keep].reset_index(drop=True)
adata = adata[keep]
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color)

color_spec = color_spec.apply_transfunc(render_params.transfunc)
Expand Down Expand Up @@ -1550,8 +1513,8 @@ def _render_points(
update_parameters = not _mpl_ax_contains_elements(ax)
cax = _scatter_points(
ax,
adata[:, 0].X.flatten(),
adata[:, 1].X.flatten(),
points["x"].to_numpy(),
points["y"].to_numpy(),
color_spec.color_vector,
size=render_params.size,
cmap=render_params.cmap_params.cmap,
Expand All @@ -1570,7 +1533,6 @@ def _render_points(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
col_for_color=col_for_color,
color_spec=color_spec,
palette=None,
Expand Down Expand Up @@ -2409,7 +2371,6 @@ def _render_labels(
),
norm=render_params.cmap_params.fresh_norm(), # ax.scatter autoscales in place; don't mutate the shared norm
na_color=na_color,
adata=table if table_name is not None else None,
col_for_color=col_for_color,
palette=palette,
fig_params=fig_params,
Expand Down Expand Up @@ -2510,7 +2471,6 @@ def _draw_labels(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
col_for_color=col_for_color,
color_spec=color_spec,
palette=palette,
Expand Down
16 changes: 15 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pandas.api.types import CategoricalDtype, is_numeric_dtype
from pandas.core.arrays.categorical import Categorical
from scanpy import settings
from scanpy.plotting import palettes
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from spatialdata import (
SpatialData,
Expand Down Expand Up @@ -447,14 +448,19 @@ def _stack_categorical_legend(
new_leg._sdata_column = column # type: ignore[attr-defined]


# Upper bound on the number of categories for which a per-entry categorical legend is still drawn;
# `_decorate_axs` skips the legend (with a warning) once `len(clusters)` exceeds this value.
# A per-entry legend past this many categories is unreadable, and scanpy builds it in O(categories^2)
# (one autoscaling artist each), dominating the render. Tied to scanpy's default_102 palette, beyond
# which its *default* colors also stop being distinguishable (uniform grey).
_MAX_LEGEND_CATEGORIES = len(palettes.default_102)


def _decorate_axs(
ax: Axes,
cax: PatchCollection,
fig_params: FigParams,
value_to_plot: str | None,
color_source_vector: pd.Series[CategoricalDtype] | Categorical,
color_vector: pd.Series[CategoricalDtype] | Categorical,
adata: AnnData | None = None,
palette: ListedColormap | str | list[str] | None = None,
alpha: float = 1.0,
na_color: Color = Color("default"),
Expand Down Expand Up @@ -499,6 +505,14 @@ def _decorate_axs(
already = any(tagged)
if legend_loc in (None, "none"):
pass # legend suppressed
elif len(clusters) > _MAX_LEGEND_CATEGORIES:
# A per-entry legend this large is unreadable and scanpy builds it in O(categories^2)
# (one autoscaling artist each), dominating the render. Skip it.
logger.warning(
f"Skipping the categorical legend for '{value_to_plot}': {len(clusters)} categories "
f"exceed the {_MAX_LEGEND_CATEGORIES}-entry limit (unreadable and very slow to build). "
f"Pass a `groups` subset to get a legend."
)
elif already:
na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None
_stack_categorical_legend(
Expand Down
Binary file modified tests/_images/Points_can_stack_render_points.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_datashader_matplotlib_stack.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading