diff --git a/docs/plotting.rst b/docs/plotting.rst index 2455b3bc..eb148534 100644 --- a/docs/plotting.rst +++ b/docs/plotting.rst @@ -7,3 +7,6 @@ Plotting (`.pl`) .. automodule:: spatialdata_plot.pl.basic :members: + +.. autoclass:: spatialdata_plot.PercentileNormalize + :members: diff --git a/src/spatialdata_plot/__init__.py b/src/spatialdata_plot/__init__.py index 76c7d18c..3fe95f2f 100644 --- a/src/spatialdata_plot/__init__.py +++ b/src/spatialdata_plot/__init__.py @@ -3,7 +3,8 @@ from . import pl from ._logging import set_verbosity from ._settings import Verbosity +from .pl._color import PercentileNormalize -__all__ = ["pl", "set_verbosity", "Verbosity"] +__all__ = ["PercentileNormalize", "Verbosity", "pl", "set_verbosity"] __version__ = version("spatialdata-plot") diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py index b3e3425a..f1e31797 100644 --- a/src/spatialdata_plot/pl/_color.py +++ b/src/spatialdata_plot/pl/_color.py @@ -110,31 +110,81 @@ def _make_continuous_mappable(vmin: float, vmax: float, cmap: Any) -> ScalarMapp return ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap) +class PercentileNormalize(Normalize): + """:class:`~matplotlib.colors.Normalize` that autoscales to data percentiles instead of min/max. + + Heavy-tailed images (fluorescence, Xenium morphology) have a few very bright pixels, so the + default min/max mapping crushes the bulk of the signal into near-black. ``PercentileNormalize`` + derives ``vmin``/``vmax`` from the ``pmin``/``pmax`` percentiles of the data instead, which + matches the per-channel contrast limits used by viewers like Xenium Explorer. + + It plugs into the existing ``norm`` argument like any other ``Normalize``: a single instance is + autoscaled independently per channel, and a list applies channelwise limits. + + Parameters + ---------- + pmin + Lower percentile in ``[0, 100]`` (``pmin < pmax``) used to derive ``vmin``. + pmax + Upper percentile in ``[0, 100]`` used to derive ``vmax``. + clip + Forwarded to :class:`~matplotlib.colors.Normalize`. + + Notes + ----- + Explicitly setting ``vmin``/``vmax`` overrides the corresponding percentile. On the datashader + backend contrast autoscales to the aggregate range rather than to these percentiles. + """ + + def __init__(self, pmin: float = 0.0, pmax: float = 100.0, clip: bool = False) -> None: + if not 0.0 <= pmin < pmax <= 100.0: + raise ValueError(f"Require 0 <= pmin < pmax <= 100, got pmin={pmin}, pmax={pmax}.") + super().__init__(vmin=None, vmax=None, clip=clip) + self.pmin = pmin + self.pmax = pmax + + def autoscale_None(self, A: Any) -> None: + """Fill unset ``vmin``/``vmax`` from the ``pmin``/``pmax`` percentiles of finite values.""" + finite = np.ma.masked_invalid(np.ma.asarray(A)).compressed() # drops mask + NaN/inf + if finite.size: + if self.vmin is None: + self.vmin = float(np.percentile(finite, self.pmin)) + if self.vmax is None: + self.vmax = float(np.percentile(finite, self.pmax)) + + def _resolve_continuous_norm(values: Any, cmap_params: CmapParams) -> Normalize: """Resolve ``cmap_params.norm`` with concrete vmin/vmax for continuous coloring. - Honor explicit ``norm`` vmin/vmax, else the finite-value data range of ``values``, else - ``[0, 1]``. Shared by the pixel and colorbar sites so both derive the same range. Preserves the - norm subclass (``LogNorm``/``PowerNorm``/...) so non-linear scaling is not silently linearized. + Honor explicit ``norm`` vmin/vmax, else delegate to the norm's own ``autoscale_None`` over the + finite values of ``values`` (so plain ``Normalize`` uses min/max, ``LogNorm`` uses its + positive-only range, and ``PercentileNormalize`` uses percentiles), else fall back to ``[0, 1]``. + Shared by the pixel and colorbar sites so both derive the same range; preserves the norm + subclass so non-linear scaling is not silently linearized. """ - base = cmap_params.norm - vmin, vmax = base.vmin, base.vmax - if vmin is None or vmax is None: + resolved = copy(cmap_params.norm) + if resolved.vmin is None or resolved.vmax is None: arr = np.asarray(values) if not np.issubdtype(arr.dtype, np.number): arr = pd.to_numeric(arr.ravel(), errors="coerce") - finite = np.isfinite(arr) - data_min = float(np.nanmin(arr[finite])) if finite.any() else 0.0 - data_max = float(np.nanmax(arr[finite])) if finite.any() else 1.0 - if vmin is None: - vmin = data_min - if vmax is None: - vmax = data_max - if vmin == vmax and not isinstance(base, LogNorm): - # degenerate range collapses the cmap onto its floor; fall back to [0, 1]. LogNorm exempt (0 not in domain). - vmin, vmax = 0.0, 1.0 - resolved = copy(base) - resolved.vmin, resolved.vmax = vmin, vmax + finite = arr[np.isfinite(arr)] + if finite.size: + resolved.autoscale_None(finite) + if isinstance(resolved, LogNorm): + # LogNorm needs strictly-positive bounds; all-nonpositive/empty data can't provide them + # (matplotlib leaves them at 0), so fall back to a valid domain instead of raising later. + if resolved.vmin is None or resolved.vmin <= 0: + resolved.vmin = 1.0 + if resolved.vmax is None or resolved.vmax <= 0: + resolved.vmax = 1.0 + else: + if resolved.vmin is None: + resolved.vmin = 0.0 + if resolved.vmax is None: + resolved.vmax = 1.0 + if resolved.vmin == resolved.vmax and not isinstance(resolved, LogNorm): + # a single distinct value would collapse the cmap onto its floor + resolved.vmin, resolved.vmax = 0.0, 1.0 return resolved diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 3001a317..7f5102bb 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -768,6 +768,10 @@ def render_images( A single :class:`~matplotlib.colors.Normalize` applies to all channels. A list of :class:`~matplotlib.colors.Normalize` objects applies per-channel (length must match the number of channels). + For heavy-tailed images (e.g. fluorescence/Xenium morphology) where min/max + scaling looks dim, pass :class:`~spatialdata_plot.PercentileNormalize` to clip each + channel to a percentile range (single instance for all channels, or a list for + channelwise limits). palette : list[str] | str | None Palette to color images. Can be a single palette name (broadcast to all channels) or a list matching the number of channels. diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 1ab4825a..2e2cab8c 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -34,6 +34,7 @@ from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl._color import ( ColorSpec, + ColorType, _get_colors_for_categorical_obs, _get_linear_colormap, _map_color_seg, @@ -697,8 +698,7 @@ def _render_shapes( nan_count = int(pd.isna(cv).sum()) if nan_count: logger.warning( - f"Found {nan_count} NaN values in color data. " - "These observations will be colored with the 'na_color'." + f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." ) color_spec = color_spec.evolve(color_vector=cv) @@ -968,9 +968,11 @@ def _render_shapes( path.vertices = trans.transform(path.vertices) if color_spec.is_continuous: - # Colorbar range from the same resolved norm the fill pixels use. + # Colorbar uses the same resolved norm the fill pixels use, including its subclass + # (LogNorm/PowerNorm) — set_norm, not set_clim, which would leave the collection's + # default linear Normalize in place and mis-scale the bar for non-linear norms. used_norm = _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params) - _cax.set_clim(vmin=used_norm.vmin, vmax=used_norm.vmax) + _cax.set_norm(used_norm) _add_legend_and_colorbar( ax=ax, @@ -1413,8 +1415,6 @@ def _render_points( trans, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax) - norm = render_params.cmap_params.fresh_norm() - method = render_params.method if render_params.density: @@ -1426,6 +1426,9 @@ def _render_points( _default_reduction: _DsReduction = "sum" if method == "datashader": + # datashader colors the per-pixel aggregate (count/sum/reduction), not the per-point vector, + # so pass an un-resolved norm and let _apply_ds_norm autoscale to the aggregate. + norm = render_params.cmap_params.fresh_norm() _log_datashader_method(method, render_params.ds_reduction, _default_reduction) # Apply transformations and materialize to pandas immediately so @@ -1462,6 +1465,13 @@ def _render_points( color_spec = color_spec.evolve(source_vector=csv, color_vector=cv) elif method == "matplotlib": + # matplotlib colors each point by its own value, so resolve the norm to match shapes/labels + # instead of letting ax.scatter autoscale a fresh one. Non-continuous keeps the fresh norm. + norm = ( + _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params) + if color_spec.is_continuous + else render_params.cmap_params.fresh_norm() + ) # update axis limits if plot was empty before (necessary if datashader comes after) update_parameters = not _mpl_ax_contains_elements(ax) cax = _scatter_points( @@ -2297,15 +2307,18 @@ def _render_labels( # (`_map_color_seg` Case C) instead of collapsing every dot to a single na_color. point_color_vector = np.random.default_rng(42).random((len(point_ids), 3)) point_color_source_vector = None + point_colortype: ColorType = "none" # colour is not data-driven allow_datashader = False elif len(color_spec.color_vector) == len(instance_id): - # data-driven colour is per-instance + # data-driven colour is per-instance; carry the upstream classification (invariant under mask) point_color_vector = np.asarray(color_spec.color_vector)[keep] point_color_source_vector = None if color_spec.source_vector is None else color_spec.source_vector[keep] + point_colortype = color_spec.colortype else: # literal colour / user-set na_color -> one colour per centroid point_color_vector = np.full(len(point_ids), na_color.get_hex_with_alpha()) point_color_source_vector = None + point_colortype = "none" # colour is not data-driven # transform rendered-raster intrinsic centroids to coordinate-system coords xy = trans.transform(np.column_stack([centroids["x"].to_numpy(), centroids["y"].to_numpy()])) _render_centroids_as_points( @@ -2313,9 +2326,10 @@ def _render_labels( render_params, x=xy[:, 0], y=xy[:, 1], - # point colours are derived fresh; classify by source so the spec stays self-consistent + # point colours are derived fresh; carry the resolved colortype so the spec invariant + # (categorical => pd.Categorical source) holds and `none` is not mislabelled categorical color_spec=ColorSpec( - "categorical" if point_color_source_vector is not None else "continuous", + point_colortype, point_color_source_vector, point_color_vector, ), @@ -2352,20 +2366,17 @@ def _draw_labels( outline_color_source_vector=outline_color_source_vector if seg_boundaries else None, ) - # labels is pre-baked RGB; cmap/norm only drive the colorbar, so feed the same resolved norm. - cax = ax.imshow( - labels, - rasterized=True, - cmap=None if color_spec.is_categorical else render_params.cmap_params.cmap, - norm=None - if color_spec.is_categorical - else _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params), - alpha=alpha, - origin="lower", - zorder=render_params.zorder, - ) - cax.set_transform(trans_data) - return cax + # labels is pre-baked RGB, so imshow ignores cmap/norm for display. Passing the resolved + # norm to imshow would make it try to normalize the RGBA array — which raises for a + # non-linear norm (LogNorm/PowerNorm). Display the RGB without a norm and build the + # continuous colorbar mappable separately from the resolved norm (mirrors the outline path), + # so the colorbar reflects the real norm subclass. + img = ax.imshow(labels, rasterized=True, alpha=alpha, origin="lower", zorder=render_params.zorder) + img.set_transform(trans_data) + if color_spec.is_categorical: + return img + used_norm = _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params) + return ScalarMappable(norm=used_norm, cmap=render_params.cmap_params.cmap) # When color is a literal (col_for_color is None) and no explicit outline_color, # use the literal color for outlines so they are visible (e.g., color='white' on diff --git a/tests/_images/Images_percentile_normalize_broadcast.png b/tests/_images/Images_percentile_normalize_broadcast.png new file mode 100644 index 00000000..283d26f4 Binary files /dev/null and b/tests/_images/Images_percentile_normalize_broadcast.png differ diff --git a/tests/_images/Images_percentile_normalize_channelwise.png b/tests/_images/Images_percentile_normalize_channelwise.png new file mode 100644 index 00000000..56df1942 Binary files /dev/null and b/tests/_images/Images_percentile_normalize_channelwise.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 440f4805..643e8575 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -10,6 +10,7 @@ from spatialdata.models import Image2DModel, Image3DModel import spatialdata_plot # noqa: F401 +from spatialdata_plot import PercentileNormalize from spatialdata_plot._logging import logger, logger_no_warns, logger_warns from spatialdata_plot.pl.render import _is_rgb_image from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over @@ -86,6 +87,15 @@ def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData): element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over() ).pl.show() + def test_plot_percentile_normalize_broadcast(self, sdata_blobs: SpatialData): + # single PercentileNormalize is broadcast and autoscaled per channel to its percentile range + sdata_blobs.pl.render_images(element="blobs_image", norm=PercentileNormalize(0, 90)).pl.show() + + def test_plot_percentile_normalize_channelwise(self, sdata_blobs: SpatialData): + # a list applies channelwise percentile limits + norms = [PercentileNormalize(0, 99), PercentileNormalize(0, 90), PercentileNormalize(0, 80)] + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2], norm=norms).pl.show() + def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show() diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 2048d11c..5b7bd6cd 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -540,6 +540,45 @@ def test_render_labels_all_nan_color_renders_under_rasterize(sdata_blobs: Spatia plt.close(fig) +def test_render_labels_as_points_all_nan_color_does_not_crash(sdata_blobs: SpatialData): + # Regression: as_points rebuilds the centroid ColorSpec; an all-NaN column is the "none" colortype + # and must not be re-classified "categorical" (which would later call Categorical-only methods on + # the na-array source). It must just render the centroids in na_color. + labels_name = "blobs_labels" + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData(np.zeros((n_obs, 1))) + adata.obs["instance_id"] = instances.values + adata.obs["nanvals"] = np.full(n_obs, np.nan) + adata.obs["region"] = labels_name + sdata_blobs["label_table"] = TableModel.parse( + adata=adata, region_key="region", instance_key="instance_id", region=labels_name + ) + fig, ax = plt.subplots() + sdata_blobs.pl.render_labels(labels_name, color="nanvals", table_name="label_table", as_points=True).pl.show(ax=ax) + plt.close(fig) + + +def test_render_labels_lognorm_with_zeros_does_not_crash(sdata_blobs: SpatialData): + # Regression: a continuous LogNorm column containing 0 must derive a positive vmin instead of a + # LogNorm(vmin=0) that raises "Invalid vmin or vmax" when the segmentation colors are mapped. + from matplotlib.colors import LogNorm + + labels_name = "blobs_labels" + instances = get_element_instances(sdata_blobs[labels_name]) + n_obs = len(instances) + adata = AnnData(np.zeros((n_obs, 1))) + adata.obs["instance_id"] = instances.values + adata.obs["counts"] = np.linspace(0.0, 10.0, n_obs) # includes 0 + adata.obs["region"] = labels_name + sdata_blobs["label_table"] = TableModel.parse( + adata=adata, region_key="region", instance_key="instance_id", region=labels_name + ) + fig, ax = plt.subplots() + sdata_blobs.pl.render_labels(labels_name, color="counts", table_name="label_table", norm=LogNorm()).pl.show(ax=ax) + plt.close(fig) + + @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) def test_render_labels_rejects_float_dtype(dtype): # Regression test for #606: float-dtype labels must raise a clear diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 565c856e..52a3f114 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -666,6 +666,18 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData): ).pl.show() +def test_render_points_lognorm_with_zeros_does_not_crash(sdata_blobs: SpatialData): + # Regression: matplotlib points resolve their continuous norm through the shared resolver, so a + # LogNorm column containing 0 must derive a positive vmin instead of a LogNorm(vmin=0) that raises. + from matplotlib.colors import LogNorm + + n = len(sdata_blobs["blobs_points"]) + sdata_blobs["blobs_points"]["counts"] = pd.Series(np.linspace(0.0, 10.0, n)) # includes 0 + fig, ax = plt.subplots() + sdata_blobs.pl.render_points("blobs_points", color="counts", norm=LogNorm(), method="matplotlib").pl.show(ax=ax) + plt.close(fig) + + @pytest.mark.parametrize("na_color", [None, "red"]) def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color): """Warning fires regardless of na_color when no groups match.""" diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 8552f61e..af5ff369 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1150,6 +1150,29 @@ def test_render_shapes_all_nan_color_with_groups_does_not_crash(sdata_blobs_shap plt.close(fig) +def test_render_shapes_lognorm_with_zeros_does_not_crash(sdata_blobs_shapes_annotated: SpatialData): + # Regression: a continuous LogNorm column containing 0 must derive a positive vmin instead of + # producing a LogNorm(vmin=0) that raises "Invalid vmin or vmax" when the fill is mapped. + from matplotlib.colors import LogNorm + + sdata_blobs_shapes_annotated["blobs_polygons"]["counts"] = [0.0, 2.5, 5.0, 7.5, 10.0] + fig, ax = plt.subplots() + sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="counts", norm=LogNorm()).pl.show(ax=ax) + plt.close(fig) + + +def test_render_shapes_continuous_colorbar_reflects_norm_subclass(sdata_blobs_shapes_annotated: SpatialData): + # Regression: the fill colorbar must use the resolved norm subclass (LogNorm), not the + # collection's default linear Normalize — i.e. set_norm, not set_clim. + from matplotlib.colors import LogNorm + + sdata_blobs_shapes_annotated["blobs_polygons"]["counts"] = [1.0, 2.5, 5.0, 7.5, 10.0] + fig, ax = plt.subplots() + sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="counts", norm=LogNorm()).pl.show(ax=ax) + assert any(isinstance(c.norm, LogNorm) for c in ax.collections), "fill colorbar norm was linearized" + plt.close(fig) + + def test_gene_symbols_auto_detect_table(sdata_blobs: SpatialData): """gene_symbols resolves correctly without explicit table_name (#247).""" sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index ef4aa576..2a2caf32 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -917,6 +917,43 @@ def test_object_dtype_coerces_color_strings_to_nan(self): norm = _resolve_continuous_norm(np.array([1.0, "red", 9.0], dtype=object), self._params()) assert (norm.vmin, norm.vmax) == (1.0, 9.0) + def test_lognorm_with_zero_derives_positive_vmin_and_does_not_raise(self): + # regression: deriving vmin from a LogNorm's data range must skip 0/negatives; otherwise the + # preserved LogNorm has vmin <= 0 and raises "Invalid vmin or vmax" when the renderer calls it + from matplotlib.colors import LogNorm + + from spatialdata_plot.pl._color import _resolve_continuous_norm + + vals = np.array([0.0, 5.0, 50.0]) + params = CmapParams(cmap=self._params().cmap, norm=LogNorm(vmin=None, vmax=None), na_color=Color()) + norm = _resolve_continuous_norm(vals, params) + assert isinstance(norm, LogNorm) + assert norm.vmin == 5.0 and norm.vmax == 50.0 # smallest positive, not 0 + norm(vals) # must not raise + + def test_lognorm_with_negatives_derives_positive_vmin(self): + from matplotlib.colors import LogNorm + + from spatialdata_plot.pl._color import _resolve_continuous_norm + + vals = np.array([-10.0, 1.0, 100.0]) + params = CmapParams(cmap=self._params().cmap, norm=LogNorm(vmin=None, vmax=None), na_color=Color()) + norm = _resolve_continuous_norm(vals, params) + assert isinstance(norm, LogNorm) + assert norm.vmin == 1.0 and norm.vmax == 100.0 + norm(vals) # must not raise + + def test_lognorm_all_nonpositive_falls_back_without_raising(self): + from matplotlib.colors import LogNorm + + from spatialdata_plot.pl._color import _resolve_continuous_norm + + params = CmapParams(cmap=self._params().cmap, norm=LogNorm(vmin=None, vmax=None), na_color=Color()) + norm = _resolve_continuous_norm(np.array([0.0, -1.0, np.nan]), params) + assert isinstance(norm, LogNorm) + assert (norm.vmin, norm.vmax) == (1.0, 1.0) + norm(np.array([1.0])) # must not raise + def test_same_values_give_identical_norm_and_never_mutate_shared(self): # the core #699 invariant: pixels and colorbar call this with the same vector -> same result, # and the shared CmapParams.norm is never autoscaled in place. @@ -1089,3 +1126,89 @@ def test_precomputed_rgba_passthrough(self): arr = np.array([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]]) np.testing.assert_allclose(ColorSpec("continuous", None, arr).to_rgba(self._params()), arr) + + +class TestPercentileNormalize: + """PercentileNormalize + _resolve_continuous_norm (issue #370: dim multichannel renders).""" + + @staticmethod + def _three_channel_sdata() -> SpatialData: + from spatialdata.models import Image2DModel + from spatialdata.transformations import Identity + + arr = np.zeros((3, 8, 8), dtype=np.uint16) + img = Image2DModel.parse(arr, dims=("c", "y", "x"), transformations={"global": Identity()}) + return SpatialData(images={"img": img}) + + @pytest.mark.parametrize("norm_kind", ["normalize", "percentile"]) + def test_norm_single_or_list_must_match_channels(self, norm_kind): + from matplotlib.colors import Normalize + + from spatialdata_plot import PercentileNormalize + from spatialdata_plot.pl._validate import _validate_image_render_params + + make = (lambda: Normalize(0, 200)) if norm_kind == "normalize" else (lambda: PercentileNormalize(1, 99)) + sdata = self._three_channel_sdata() + kw = { + "element": "img", + "channel": None, + "alpha": 1.0, + "palette": None, + "cmap": None, + "scale": None, + "colorbar": True, + "colorbar_params": {}, + } + # a single norm (broadcast) and a length-matching list are accepted; a mismatched list is rejected + _validate_image_render_params(sdata, norm=make(), **kw) + _validate_image_render_params(sdata, norm=[make() for _ in range(3)], **kw) + with pytest.raises(ValueError, match="must match the number of channels"): + _validate_image_render_params(sdata, norm=[make(), make()], **kw) + + def test_autoscale_uses_percentiles(self): + from spatialdata_plot import PercentileNormalize + + # heavy-tailed: one huge outlier should not set vmax under p99 + data = np.concatenate([np.linspace(0, 100, 1000), [10000.0]]) + norm = PercentileNormalize(1, 99) + norm.autoscale_None(data) + assert norm.vmin == pytest.approx(np.percentile(data, 1)) + assert norm.vmax == pytest.approx(np.percentile(data, 99)) + assert norm.vmax < 10000.0 + + def test_explicit_vmin_vmax_override_percentiles(self): + from spatialdata_plot import PercentileNormalize + + norm = PercentileNormalize(1, 99) + norm.vmin, norm.vmax = 5.0, 50.0 + norm.autoscale_None(np.linspace(0, 100, 100)) + assert (norm.vmin, norm.vmax) == (5.0, 50.0) + + @pytest.mark.parametrize("pmin,pmax", [(99, 1), (-1, 50), (50, 101), (5, 5)]) + def test_invalid_percentiles_raise(self, pmin, pmax): + from spatialdata_plot import PercentileNormalize + + with pytest.raises(ValueError): + PercentileNormalize(pmin, pmax) + + def test_nan_and_masked_values_ignored(self): + from spatialdata_plot import PercentileNormalize + + norm = PercentileNormalize(0, 100) + norm.autoscale_None(np.array([np.nan, 1.0, 2.0, 3.0, np.inf, -np.inf])) + assert (norm.vmin, norm.vmax) == (1.0, 3.0) + # masked entries must not leak into the percentiles (mask honored like matplotlib's Normalize) + masked = PercentileNormalize(0, 100) + masked.autoscale_None(np.ma.masked_array([1.0, 2.0, 3.0, 1000.0], mask=[False, False, False, True])) + assert (masked.vmin, masked.vmax) == (1.0, 3.0) + + def test_resolve_honors_percentile_norm(self): + # _resolve_continuous_norm (colorbar/shapes path) must defer to the norm's percentile autoscale; + # min/max and degenerate/empty fallbacks for builtin norms are covered by TestResolveContinuousNorm. + from spatialdata_plot import PercentileNormalize + from spatialdata_plot.pl._color import _prepare_cmap_norm, _resolve_continuous_norm + + values = np.concatenate([np.linspace(0, 100, 1000), [10000.0]]) + resolved = _resolve_continuous_norm(values, _prepare_cmap_norm(norm=PercentileNormalize(0, 99))) + assert resolved.vmax == pytest.approx(np.percentile(values, 99)) + assert resolved.vmax < 10000.0