diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index b56a52fd..eb7bb42d 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -336,6 +336,19 @@ def _ds_shade_categorical( return _apply_user_alpha(shaded, alpha) +def _color_vector_is_uniform(color_vector: Any) -> bool: + """True if every entry of the per-point colour vector is identical (so per-point colouring collapses). + + Shared by the datashader categorical collapse and the matplotlib scalar-colour fast path. + """ + if color_vector is None or len(color_vector) == 0: + return False + arr = np.asarray(color_vector) + if arr.dtype.kind in "US": # fixed-width strings (the resolved-hex case): cheap vectorised compare + return bool((arr == arr[0]).all()) + return pd.Series(arr).nunique(dropna=False) == 1 # object/other: hash-based, no sort + + def _shade_datashader_aggregate( cvs: ds.Canvas, frame: Any, @@ -365,6 +378,14 @@ def _shade_datashader_aggregate( element-specific prep (geometry transform / point parse), outline rendering, ``_render_ds_image`` and ``_build_ds_colorbar``. """ + # Single-colour collapse: when every point resolves to the same colour (e.g. past scanpy's + # 102-colour palette all categories are uniform grey), the per-category ds.by aggregate + composite + # is wasted and byte-identical to a plain single-colour count render. Detect it from the colour + # vector and route to the cheap count path — ~14x faster on high-cardinality categoricals (Xenium + # points coloured by gene). _ds_shade_categorical then colours the count by color_vector[0]. + if color_by_categorical and _color_vector_is_uniform(color_vector): + col_for_color, color_by_categorical = None, False + agg, reduction_bounds, nan_agg = _ds_aggregate( cvs, frame, col_for_color, color_by_categorical, ds_reduction, default_reduction, kind ) @@ -376,15 +397,16 @@ def _shade_datashader_aggregate( if ( strip_alpha_hex + and col_for_color is not None # no-color/collapse: _ds_shade_categorical strips color_vector[0] itself and color_vector is not None and len(color_vector) > 0 and isinstance(color_vector[0], str) and color_vector[0].startswith("#") ): - # color_vector usually holds only a few distinct hex strings (one per category), so strip - # alpha on the unique values and map back rather than parsing once per point. - unique_hex, inverse = np.unique(color_vector, return_inverse=True) - color_vector = np.asarray([_hex_no_alpha(c) for c in unique_hex])[inverse] + # Strip alpha on the unique colours and map back rather than parsing once per point; pd.factorize + # dedups in O(n) (hash, no sort) where np.unique would sort millions of strings. + codes, uniques = pd.factorize(np.asarray(color_vector)) + color_vector = np.asarray([_hex_no_alpha(c) for c in uniques])[codes] # density without a color column collapses to a sequential count gradient; everything else with no # explicit continuous value (categorical or no color) goes through the categorical shader. diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index da27e257..61b43fd9 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -15,17 +15,14 @@ import matplotlib.ticker import numpy as np import pandas as pd -import scanpy as sc import spatialdata as sd import xarray as xr -from anndata import AnnData from matplotlib import patheffects from matplotlib.cm import ScalarMappable from matplotlib.colors import Colormap, ListedColormap, Normalize from scanpy._settings import settings as sc_settings from scanpy.plotting._tools.scatterplots import _add_categorical_legend from spatialdata import get_extent, get_values -from spatialdata._core.query.relational_query import match_table_to_element from spatialdata.models import PointsModel, ShapesModel, get_table_keys from spatialdata.transformations import set_transformation from spatialdata.transformations.transformations import Identity @@ -38,7 +35,6 @@ _get_colors_for_categorical_obs, _get_linear_colormap, _map_color_seg, - _maybe_set_colors, _prepare_cmap_norm, _resolve_continuous_norm, resolve_color, @@ -48,6 +44,7 @@ _ax_show_and_transform, _build_ds_colorbar, _circle_buffer_quad_segs, + _color_vector_is_uniform, _datashader_canvas_from_dataframe, _get_extent_and_range_for_datashader_canvas, _hex_no_alpha, @@ -333,7 +330,6 @@ def _add_legend_and_colorbar( ax: matplotlib.axes.SubplotBase, cax: ScalarMappable | None, fig_params: FigParams, - adata: AnnData | None, col_for_color: str | None, color_spec: ColorSpec, palette: ListedColormap | list[str] | None, @@ -387,7 +383,6 @@ def _add_legend_and_colorbar( ax=ax, cax=cax, fig_params=fig_params, - adata=adata, value_to_plot=col_for_color, color_source_vector=color_source_vector, color_vector=color_vector, @@ -755,7 +750,6 @@ def _draw_centroids(xy: np.ndarray, radius: float | None = None) -> None: color_spec=color_spec, norm=norm, na_color=render_params.cmap_params.na_color, - adata=table, col_for_color=col_for_color, palette=palette, fig_params=fig_params, @@ -1025,7 +1019,6 @@ def _draw_centroids(xy: np.ndarray, radius: float | None = None) -> None: ax=ax, cax=cax, fig_params=fig_params, - adata=table, col_for_color=col_for_color, color_spec=color_spec, palette=palette, @@ -1059,18 +1052,26 @@ def _scatter_points( Shared scatter primitive for points and the centroid "fast mode" of shapes/labels; ``color_vector`` is per-point hex strings or numeric values mapped through ``cmap``/``norm``. """ + # When every marker is the same resolved hex colour (no-color / single colour / collapsed grey), pass a + # scalar ``color=`` instead of a per-point ``c=`` array: matplotlib then skips its per-point colour + # machinery — the dominant cost at scale (10M points: ~9s -> ~3.7s) — for a visually identical result. + # Numeric vectors keep the ``c=``/``cmap``/``norm`` path (they need the colormap). + cv = np.asarray(color_vector) + color_kwargs: dict[str, Any] + if cv.ndim == 1 and cv.dtype.kind in "US" and _color_vector_is_uniform(cv): + color_kwargs = {"color": str(cv[0])} + else: + color_kwargs = {"c": color_vector, "cmap": cmap, "norm": norm} return ax.scatter( x, y, s=size, - c=color_vector, rasterized=sc_settings._vector_friendly, - cmap=cmap, - norm=norm, alpha=alpha, transform=trans_data, zorder=zorder, plotnonfinite=True, # nan points should be rendered as well + **color_kwargs, ) @@ -1110,7 +1111,6 @@ def _render_centroids_as_points( color_spec: ColorSpec, norm: Normalize | None, na_color: Any, - adata: AnnData | None, col_for_color: str | None, palette: Any, fig_params: FigParams, @@ -1170,7 +1170,6 @@ def _render_centroids_as_points( ax=ax, cax=cax, fig_params=fig_params, - adata=adata, col_for_color=col_for_color, color_spec=color_spec, palette=palette, @@ -1382,50 +1381,15 @@ def _render_points( else points_pd_with_color ) - # we construct an anndata to hack the plotting functions - if table_name is None: - adata = AnnData( - X=points[["x", "y"]].values, - obs=points[coords], - dtype=points[["x", "y"]].values.dtype, - ) - else: - matched_table = match_table_to_element(sdata=sdata, element_name=element, table_name=table_name) - adata_obs = matched_table.obs.copy() - # if the points are colored by values in X (or a different layer), add the values to obs - if col_for_color in matched_table.var_names: - if table_layer is None: - adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten() - else: - adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten() - adata = AnnData( - X=points[["x", "y"]].values, - obs=adata_obs, - dtype=points[["x", "y"]].values.dtype, - uns=matched_table.uns, - ) - sdata_filt[table_name] = adata - - # we can modify the sdata because of dealing with a copy + # Color (from a points column, table obs, or table var/X) is already materialized on `points` via + # the get_values merge above; coordinates and the legend come straight from `points`/`color_spec`, + # so no AnnData round-trip is needed. resolve_color reads the original table (kept in sdata_filt) + # for any user-defined uns palette. # Convert back to dask dataframe to modify sdata transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system] _reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color) - if col_for_color is not None: - assert isinstance(col_for_color, str) - cols = sc.get.obs_df(adata, [col_for_color]) - # maybe set color based on type - if isinstance(cols[col_for_color].dtype, pd.CategoricalDtype): - uns_color_key = f"{col_for_color}_colors" - if uns_color_key in adata.uns: - _maybe_set_colors( - source=adata, - target=adata, - key=col_for_color, - palette=palette, - ) - # when user specified a single color, we emulate the form of `na_color` and use it default_color = ( render_params.color if col_for_color is None and color is not None else render_params.cmap_params.na_color @@ -1478,9 +1442,8 @@ def _render_points( color_spec = color_spec.filter(keep) if int(keep.sum()) == 0: return - # filter the materialized points, adata, and re-register in sdata_filt + # filter the materialized points and re-register in sdata_filt points = points[keep].reset_index(drop=True) - adata = adata[keep] _reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color) color_spec = color_spec.apply_transfunc(render_params.transfunc) @@ -1550,8 +1513,8 @@ def _render_points( update_parameters = not _mpl_ax_contains_elements(ax) cax = _scatter_points( ax, - adata[:, 0].X.flatten(), - adata[:, 1].X.flatten(), + points["x"].to_numpy(), + points["y"].to_numpy(), color_spec.color_vector, size=render_params.size, cmap=render_params.cmap_params.cmap, @@ -1570,7 +1533,6 @@ def _render_points( ax=ax, cax=cax, fig_params=fig_params, - adata=adata, col_for_color=col_for_color, color_spec=color_spec, palette=None, @@ -2409,7 +2371,6 @@ def _render_labels( ), norm=render_params.cmap_params.fresh_norm(), # ax.scatter autoscales in place; don't mutate the shared norm na_color=na_color, - adata=table if table_name is not None else None, col_for_color=col_for_color, palette=palette, fig_params=fig_params, @@ -2510,7 +2471,6 @@ def _draw_labels( ax=ax, cax=cax, fig_params=fig_params, - adata=table, col_for_color=col_for_color, color_spec=color_spec, palette=palette, diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index f61615c8..8d12ae20 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -31,6 +31,7 @@ from pandas.api.types import CategoricalDtype, is_numeric_dtype from pandas.core.arrays.categorical import Categorical from scanpy import settings +from scanpy.plotting import palettes from scanpy.plotting._tools.scatterplots import _add_categorical_legend from spatialdata import ( SpatialData, @@ -447,6 +448,12 @@ def _stack_categorical_legend( new_leg._sdata_column = column # type: ignore[attr-defined] +# A per-entry legend past this many categories is unreadable, and scanpy builds it in O(categories^2) +# (one autoscaling artist each), dominating the render — so skip it with a warning. Tied to scanpy's +# default_102 palette, beyond which its *default* colors also stop being distinguishable (uniform grey). +_MAX_LEGEND_CATEGORIES = len(palettes.default_102) + + def _decorate_axs( ax: Axes, cax: PatchCollection, @@ -454,7 +461,6 @@ def _decorate_axs( value_to_plot: str | None, color_source_vector: pd.Series[CategoricalDtype] | Categorical, color_vector: pd.Series[CategoricalDtype] | Categorical, - adata: AnnData | None = None, palette: ListedColormap | str | list[str] | None = None, alpha: float = 1.0, na_color: Color = Color("default"), @@ -499,6 +505,14 @@ def _decorate_axs( already = any(tagged) if legend_loc in (None, "none"): pass # legend suppressed + elif len(clusters) > _MAX_LEGEND_CATEGORIES: + # A per-entry legend this large is unreadable and scanpy builds it in O(categories^2) + # (one autoscaling artist each), dominating the render. Skip it. + logger.warning( + f"Skipping the categorical legend for '{value_to_plot}': {len(clusters)} categories " + f"exceed the {_MAX_LEGEND_CATEGORIES}-entry limit (unreadable and very slow to build). " + f"Pass a `groups` subset to get a legend." + ) elif already: na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None _stack_categorical_legend( diff --git a/tests/_images/Points_can_stack_render_points.png b/tests/_images/Points_can_stack_render_points.png index 8b008723..d38b0eb5 100644 Binary files a/tests/_images/Points_can_stack_render_points.png and b/tests/_images/Points_can_stack_render_points.png differ diff --git a/tests/_images/Points_datashader_matplotlib_stack.png b/tests/_images/Points_datashader_matplotlib_stack.png index 7016d8e1..a8ce1312 100644 Binary files a/tests/_images/Points_datashader_matplotlib_stack.png and b/tests/_images/Points_datashader_matplotlib_stack.png differ