diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 9c5584e3..fa9bfd39 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -1673,7 +1673,10 @@ def _resolve_coordinate_systems( if cs not in sdata.coordinate_systems: raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") - elements_to_be_rendered = _get_elements_to_be_rendered(render_cmds, cs_index, cs) + # Union elements across all coordinate systems, not just the last one validated above. + elements_to_be_rendered = list( + dict.fromkeys(e for cs in coordinate_systems for e in _get_elements_to_be_rendered(render_cmds, cs_index, cs)) + ) # filter out cs without relevant elements cmds = [cmd for cmd, _ in render_cmds] diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 74a929f4..33c07739 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -791,3 +791,26 @@ def test_missing_instances_become_nan(self): new = _extract_color_column(sdata["table"], "g0", origin="var", element=sdata["shapes"], element_name="shapes") assert len(new) == 30 assert int(new.isna().sum()) == 5 + + +def test_show_renders_all_coordinate_systems_for_distributed_elements(): + """Regression for #694: element types split across coordinate systems must all render. + + Before the fix, ``_resolve_coordinate_systems`` computed the rendered element set from a + leaked loop variable (the last validated CS only), so ``_get_valid_cs`` could drop a CS whose + elements lived elsewhere. ``_get_elements_to_be_rendered`` keys on element *type*, so the bug + only surfaces when the leftover CS lacks a queued element type: here the image lives in ``cs_a`` + and the labels in ``cs_b``, so the leaked ``cs_b`` (no image) used to drop ``cs_a``. + """ + from spatialdata.models import Image2DModel + from spatialdata.transformations import Identity + + rng = np.random.default_rng(0) + img = Image2DModel.parse(rng.random((3, 16, 16)), transformations={"cs_a": Identity()}) + lab = Labels2DModel.parse(rng.integers(0, 5, (16, 16)), transformations={"cs_b": Identity()}) + sdata = SpatialData(images={"img": img}, labels={"lab": lab}) + + axes = sdata.pl.render_images("img").pl.render_labels("lab").pl.show(return_ax=True) + axes = axes if isinstance(axes, list) else [axes] + assert {ax.get_title() for ax in axes} == {"cs_a", "cs_b"} + plt.close("all")