diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index be6cd741..9448dbb8 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -30,6 +30,7 @@ from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render import ( + _draw_channel_legend, _render_images, _render_labels, _render_points, @@ -40,6 +41,7 @@ CBAR_DEFAULT_FRACTION, CBAR_DEFAULT_LOCATION, CBAR_DEFAULT_PAD, + ChannelLegendEntry, CmapParams, ColorbarSpec, ImageRenderParams, @@ -532,9 +534,10 @@ def render_images( alpha: float | int = 1.0, scale: str | None = None, grayscale: bool = False, - transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None, + transfunc: (Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None) = None, colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, + channels_as_legend: bool = False, ) -> sd.SpatialData: """ Render image elements in SpatialData. @@ -608,6 +611,13 @@ def render_images( colorbar_params : dict[str, object] | None Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``, and ``label``. + channels_as_legend : bool, default False + When ``True`` and rendering multiple channels, show a categorical + legend mapping each channel name to its compositing color. The + legend uses the ``legend_*`` parameters from :meth:`show`. + Ignored for single-channel and RGB(A) images. When multiple + ``render_images`` calls use this flag on the same axes, all + channel entries are combined into a single legend. Notes ----- @@ -690,6 +700,7 @@ def render_images( colorbar_params=param_values["colorbar_params"], transfunc=transfunc, grayscale=grayscale, + channels_as_legend=channels_as_legend, ) n_steps += 1 @@ -1194,6 +1205,7 @@ def _draw_colorbar( ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] assert isinstance(ax, Axes) axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None + axis_channel_legend_entries: list[ChannelLegendEntry] = [] wants_images = False wants_labels = False @@ -1224,6 +1236,7 @@ def _draw_colorbar( scalebar_params=scalebar_params, legend_params=legend_params, colorbar_requests=axis_colorbar_requests, + channel_legend_entries=axis_channel_legend_entries, rasterize=rasterize, ) @@ -1270,7 +1283,10 @@ def _draw_colorbar( table = params_copy.table_name if table is not None and params_copy.col_for_color is not None: colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) - if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype): + if isinstance( + colors[params_copy.col_for_color].dtype, + pd.CategoricalDtype, + ): _maybe_set_colors( source=sdata[table], target=sdata[table], @@ -1333,6 +1349,9 @@ def _draw_colorbar( if legend_params.colorbar and axis_colorbar_requests: pending_colorbars.append((ax, axis_colorbar_requests)) + if axis_channel_legend_entries: + _draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params) + if pending_colorbars and fig_params.fig is not None: fig = fig_params.fig fig.canvas.draw() diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 8cfb72e9..892dbf6a 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -2,6 +2,7 @@ import dataclasses from collections import abc +from collections.abc import Sequence from copy import copy from typing import Any @@ -18,9 +19,11 @@ import spatialdata as sd import xarray as xr from anndata import AnnData +from matplotlib import patheffects from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings +from scanpy.plotting._tools.scatterplots import _add_categorical_legend from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata._core.query.relational_query import match_table_to_element from spatialdata.models import PointsModel, ShapesModel, get_table_keys @@ -41,6 +44,7 @@ _render_ds_outlines, ) from spatialdata_plot.pl.render_params import ( + ChannelLegendEntry, CmapParams, Color, ColorbarSpec, @@ -185,7 +189,9 @@ def _filter_groups_transparent_na( return keep, filtered_csv, filtered_cv -def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]: +def _split_colorbar_params( + params: dict[str, object] | None, +) -> tuple[dict[str, object], dict[str, object], str | None]: """Split colorbar params into layout hints, Matplotlib kwargs, and label override.""" layout: dict[str, object] = {} cbar_kwargs: dict[str, object] = {} @@ -206,7 +212,10 @@ def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, def _resolve_colorbar_label( - colorbar_params: dict[str, object] | None, fallback: str | None, *, is_default_channel_name: bool = False + colorbar_params: dict[str, object] | None, + fallback: str | None, + *, + is_default_channel_name: bool = False, ) -> str | None: """Pick a colorbar label from params or fall back to provided value.""" _, _, label = _split_colorbar_params(colorbar_params) @@ -366,7 +375,7 @@ def _render_shapes( value_to_plot=col_for_color, groups=groups, palette=render_params.palette, - na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color, + na_color=(render_params.color if render_params.color is not None else render_params.cmap_params.na_color), cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, @@ -440,7 +449,10 @@ def _render_shapes( if not (render_params.shape == "circle" and (current_type == "Point").all()): logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.") max_extent = np.max( - [shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]] + [ + shapes.total_bounds[2] - shapes.total_bounds[0], + shapes.total_bounds[3] - shapes.total_bounds[1], + ] ) shapes = _convert_shapes(shapes, render_params.shape, max_extent) @@ -565,7 +577,15 @@ def _render_shapes( na_color_hex, ) - _render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext) + _render_ds_outlines( + cvs, + transformed_element, + render_params, + fig_params, + ax, + factor, + x_ext + y_ext, + ) _cax = _render_ds_image( ax, @@ -832,7 +852,13 @@ def _render_points( ) if added_color_from_table and col_for_color is not None: - _reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system) + _reparse_points( + sdata_filt, + element, + points_pd_with_color, + transformation_in_cs, + coordinate_system, + ) _warn_groups_ignored_continuous(groups, color_source_vector, col_for_color) @@ -1094,6 +1120,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]: return False, False +def _collect_channel_legend_entries( + channels: Sequence[str | int], + seed_colors: Sequence[str | tuple[float, ...]], + channel_legend_entries: list[ChannelLegendEntry], +) -> None: + """Accumulate channel-to-color mappings for a deferred combined legend.""" + channel_names = [str(ch) for ch in channels] + if len(set(channel_names)) != len(channel_names): + logger.warning("channels_as_legend: duplicate channel names detected; skipping legend entries.") + return + + color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors] + for name, color in zip(channel_names, color_hexes, strict=True): + channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color)) + + +def _draw_channel_legend( + ax: matplotlib.axes.SubplotBase, + entries: list[ChannelLegendEntry], + legend_params: LegendParams, + fig_params: FigParams, +) -> None: + """Draw a single combined categorical legend from accumulated channel entries. + + Because ``_add_categorical_legend`` adds invisible labeled scatter artists, + calling it here automatically merges with any earlier legend entries + (e.g. from labels or shapes) on the same axes via ``ax.legend()``. + + ``multi_panel`` is only set when no prior legend exists on the axis, + to avoid shrinking the axes twice (once for labels/shapes, once for + channels). + """ + # Deduplicate: if the same channel name appears twice, keep the last color + palette_dict: dict[str, str] = {} + for entry in entries: + palette_dict[entry.channel_name] = entry.color_hex + + legend_loc = legend_params.legend_loc + if legend_loc == "on data": + logger.warning( + "legend_loc='on data' is not supported for channel legends (no scatter coordinates); " + "falling back to 'right margin'." + ) + legend_loc = "right margin" + + categories = pd.Categorical(list(palette_dict)) + + path_effect = ( + [patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")] + if legend_params.legend_fontoutline is not None + else [] + ) + + # Only apply multi_panel shrink if no legend already exists on this axis + # (labels/shapes draw their legend during the render loop and already shrink). + has_existing_legend = ax.get_legend() is not None + needs_multi_panel = fig_params.axs is not None and not has_existing_legend + + _add_categorical_legend( + ax, + categories, + palette=palette_dict, + legend_loc=legend_loc, + legend_fontweight=legend_params.legend_fontweight, + legend_fontsize=legend_params.legend_fontsize, + legend_fontoutline=path_effect, + na_color=["lightgray"], + na_in_legend=False, + multi_panel=needs_multi_panel, + ) + + def _render_images( sdata: sd.SpatialData, render_params: ImageRenderParams, @@ -1104,6 +1202,7 @@ def _render_images( legend_params: LegendParams, rasterize: bool, colorbar_requests: list[ColorbarSpec] | None = None, + channel_legend_entries: list[ChannelLegendEntry] | None = None, ) -> None: _log_context.set("render_images") sdata_filt = sdata.filter_by_coordinate_system( @@ -1325,10 +1424,14 @@ def _render_images( layers[ch] = ch_norm(layers[ch]) + # Colors for the channel legend (set by each branch if applicable) + legend_colors: list[str] | None = None + # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): if render_params.cmap_params.cmap_is_default: # -> use RGB stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1) + legend_colors = ["red", "green", "blue"] else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( @@ -1410,6 +1513,8 @@ def _render_images( f"multichannel strategy 'stack' to render." ) # TODO: update when pca is added as strategy + legend_colors = seed_colors + _ax_show_and_transform( colored, trans_data, @@ -1427,6 +1532,8 @@ def _render_images( colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) colored = np.clip(colored[:, :, :3], 0, 1) + legend_colors = list(palette) + _ax_show_and_transform( colored, trans_data, @@ -1446,6 +1553,8 @@ def _render_images( ) colored = colored[:, :, :3] + legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps] + _ax_show_and_transform( colored, trans_data, @@ -1458,6 +1567,17 @@ def _render_images( elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") + # Collect channel legend entries (single point for all multi-channel paths) + if render_params.channels_as_legend and channel_legend_entries is not None: + if legend_colors is not None: + _collect_channel_legend_entries(channels, legend_colors, channel_legend_entries) + else: + logger.warning( + "channels_as_legend requires distinct per-channel colors; " + "ignored when a single cmap is shared across channels. " + "Use 'palette' or a list of cmaps instead." + ) + def _render_labels( sdata: sd.SpatialData, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 7cbc68f5..16f81578 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -37,7 +37,9 @@ class Color: user_defined_alpha: bool = False def __init__( - self, color: None | str | list[float] | tuple[float, ...] = "default", alpha: float | int | None = None + self, + color: None | str | list[float] | tuple[float, ...] = "default", + alpha: float | int | None = None, ) -> None: # 1) Validate alpha value if alpha is None: @@ -199,6 +201,14 @@ class ColorbarSpec: alpha: float | None = None +@dataclass +class ChannelLegendEntry: + """A single channel-to-color mapping for the categorical channel legend.""" + + channel_name: str + color_hex: str + + CBAR_DEFAULT_LOCATION = "right" CBAR_DEFAULT_FRACTION = 0.075 CBAR_DEFAULT_PAD = 0.015 @@ -274,6 +284,7 @@ class ImageRenderParams: colorbar_params: dict[str, object] | None = None transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None grayscale: bool = False + channels_as_legend: bool = False @dataclass diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_legend_lower_right.png b/tests/_images/ChannelsAsCategories_channels_as_legend_legend_lower_right.png new file mode 100644 index 00000000..df57f52c Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_legend_lower_right.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_legend_upper_left.png b/tests/_images/ChannelsAsCategories_channels_as_legend_legend_upper_left.png new file mode 100644 index 00000000..79502b90 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_legend_upper_left.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_many_channels.png b/tests/_images/ChannelsAsCategories_channels_as_legend_many_channels.png new file mode 100644 index 00000000..3332f0fe Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_many_channels.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_three_channels_default.png b/tests/_images/ChannelsAsCategories_channels_as_legend_three_channels_default.png new file mode 100644 index 00000000..537aa53e Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_three_channels_default.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_two_channels.png b/tests/_images/ChannelsAsCategories_channels_as_legend_two_channels.png new file mode 100644 index 00000000..b2a0829d Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_two_channels.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_with_cmap_list.png b/tests/_images/ChannelsAsCategories_channels_as_legend_with_cmap_list.png new file mode 100644 index 00000000..41a9cc62 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_with_cmap_list.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_with_palette.png b/tests/_images/ChannelsAsCategories_channels_as_legend_with_palette.png new file mode 100644 index 00000000..1c018591 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_with_palette.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 0bae024b..03e21a8e 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -42,7 +42,11 @@ def test_plot_can_pass_str_cmap_list(self, sdata_blobs: SpatialData): def test_plot_can_pass_cmap_list(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images( element="blobs_image", - cmap=[matplotlib.colormaps["seismic"], matplotlib.colormaps["Reds"], matplotlib.colormaps["Blues"]], + cmap=[ + matplotlib.colormaps["seismic"], + matplotlib.colormaps["Reds"], + matplotlib.colormaps["Blues"], + ], ).pl.show() def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialData): @@ -491,3 +495,105 @@ def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData) sdata_blobs.pl.render_images("blobs_image", channel=[0], cmap=["gray"]).pl.show(ax=ax) assert len(ax.get_images()) == 1 plt.close(fig) + + +# --------------------------------------------------------------------------- +# channels_as_legend visual tests (#459) +# --------------------------------------------------------------------------- + + +class TestChannelsAsCategories(PlotTester, metaclass=PlotTesterMeta): + def test_plot_channels_as_legend_two_channels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_legend=True).pl.show() + + def test_plot_channels_as_legend_three_channels_default(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channels_as_legend=True).pl.show() + + def test_plot_channels_as_legend_with_palette(self, sdata_blobs_str: SpatialData): + sdata_blobs_str.pl.render_images( + element="blobs_image", + channel=["c1", "c2", "c3"], + palette=["red", "green", "blue"], + channels_as_legend=True, + ).pl.show() + + def test_plot_channels_as_legend_many_channels(self, sdata_blobs_str: SpatialData): + sdata_blobs_str.pl.render_images(element="blobs_image", channels_as_legend=True).pl.show() + + def test_plot_channels_as_legend_with_cmap_list(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + cmap=["Reds", "Greens", "Blues"], + channels_as_legend=True, + ).pl.show() + + def test_plot_channels_as_legend_legend_upper_left(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_legend=True).pl.show( + legend_loc="upper left" + ) + + def test_plot_channels_as_legend_legend_lower_right(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_legend=True).pl.show( + legend_loc="lower right" + ) + + +class TestChannelsAsCategoriesNonVisual: + """Non-visual tests for channels_as_legend edge cases.""" + + def test_channels_as_legend_ignored_for_single_channel(self, sdata_blobs: SpatialData): + fig, ax = plt.subplots() + sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_legend=True).pl.show(ax=ax) + assert ax.get_legend() is None + plt.close("all") + + def test_channels_as_legend_false_no_legend(self, sdata_blobs: SpatialData): + fig, ax = plt.subplots() + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_legend=False).pl.show(ax=ax) + assert ax.get_legend() is None + plt.close("all") + + def test_channels_as_legend_chained_renders_combine(self, sdata_blobs: SpatialData): + """Multiple render_images with channels_as_legend should produce one combined legend.""" + fig, ax = plt.subplots() + ( + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1], + palette=["red", "green"], + channels_as_legend=True, + ) + .pl.render_images( + element="blobs_image", + channel=[1, 2], + palette=["cyan", "blue"], + channels_as_legend=True, + ) + .pl.show(ax=ax) + ) + legend = ax.get_legend() + assert legend is not None + labels = [t.get_text() for t in legend.get_texts()] + # Both render calls contribute: channels 0, 1, 2. + # Channel "1" appears in both calls — dedup keeps the last color. + assert "0" in labels + assert "1" in labels + assert "2" in labels + assert len(labels) == 3 + plt.close("all") + + def test_channels_as_legend_coexists_with_other_elements(self, sdata_blobs: SpatialData): + """Channel legend should not crash when combined with other render calls.""" + fig, ax = plt.subplots() + ( + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_legend=True) + .pl.render_labels(element="blobs_labels") + .pl.show(ax=ax) + ) + legend = ax.get_legend() + assert legend is not None + labels = [t.get_text() for t in legend.get_texts()] + assert "0" in labels + assert "1" in labels + plt.close("all")