Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_images,
_render_labels,
_render_points,
Expand All @@ -40,6 +41,7 @@
CBAR_DEFAULT_FRACTION,
CBAR_DEFAULT_LOCATION,
CBAR_DEFAULT_PAD,
ChannelLegendEntry,
CmapParams,
ColorbarSpec,
ImageRenderParams,
Expand Down Expand Up @@ -532,9 +534,10 @@ def render_images(
alpha: float | int = 1.0,
scale: str | None = None,
grayscale: bool = False,
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None,
transfunc: (Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None) = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
channels_as_legend: bool = False,
) -> sd.SpatialData:
"""
Render image elements in SpatialData.
Expand Down Expand Up @@ -608,6 +611,13 @@ def render_images(
colorbar_params : dict[str, object] | None
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
and ``label``.
channels_as_legend : bool, default False
When ``True`` and rendering multiple channels, show a categorical
legend mapping each channel name to its compositing color. The
legend uses the ``legend_*`` parameters from :meth:`show`.
Ignored for single-channel and RGB(A) images. When multiple
``render_images`` calls use this flag on the same axes, all
channel entries are combined into a single legend.

Notes
-----
Expand Down Expand Up @@ -690,6 +700,7 @@ def render_images(
colorbar_params=param_values["colorbar_params"],
transfunc=transfunc,
grayscale=grayscale,
channels_as_legend=channels_as_legend,
)
n_steps += 1

Expand Down Expand Up @@ -1194,6 +1205,7 @@ def _draw_colorbar(
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
assert isinstance(ax, Axes)
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
axis_channel_legend_entries: list[ChannelLegendEntry] = []

wants_images = False
wants_labels = False
Expand Down Expand Up @@ -1224,6 +1236,7 @@ def _draw_colorbar(
scalebar_params=scalebar_params,
legend_params=legend_params,
colorbar_requests=axis_colorbar_requests,
channel_legend_entries=axis_channel_legend_entries,
rasterize=rasterize,
)

Expand Down Expand Up @@ -1270,7 +1283,10 @@ def _draw_colorbar(
table = params_copy.table_name
if table is not None and params_copy.col_for_color is not None:
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
if isinstance(
colors[params_copy.col_for_color].dtype,
pd.CategoricalDtype,
):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
Expand Down Expand Up @@ -1333,6 +1349,9 @@ def _draw_colorbar(
if legend_params.colorbar and axis_colorbar_requests:
pending_colorbars.append((ax, axis_colorbar_requests))

if axis_channel_legend_entries:
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params)

if pending_colorbars and fig_params.fig is not None:
fig = fig_params.fig
fig.canvas.draw()
Expand Down
132 changes: 126 additions & 6 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
from collections import abc
from collections.abc import Sequence
from copy import copy
from typing import Any

Expand All @@ -18,9 +19,11 @@
import spatialdata as sd
import xarray as xr
from anndata import AnnData
from matplotlib import patheffects
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata._core.query.relational_query import match_table_to_element
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
Expand All @@ -41,6 +44,7 @@
_render_ds_outlines,
)
from spatialdata_plot.pl.render_params import (
ChannelLegendEntry,
CmapParams,
Color,
ColorbarSpec,
Expand Down Expand Up @@ -185,7 +189,9 @@ def _filter_groups_transparent_na(
return keep, filtered_csv, filtered_cv


def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
def _split_colorbar_params(
params: dict[str, object] | None,
) -> tuple[dict[str, object], dict[str, object], str | None]:
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
layout: dict[str, object] = {}
cbar_kwargs: dict[str, object] = {}
Expand All @@ -206,7 +212,10 @@ def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str,


def _resolve_colorbar_label(
colorbar_params: dict[str, object] | None, fallback: str | None, *, is_default_channel_name: bool = False
colorbar_params: dict[str, object] | None,
fallback: str | None,
*,
is_default_channel_name: bool = False,
) -> str | None:
"""Pick a colorbar label from params or fall back to provided value."""
_, _, label = _split_colorbar_params(colorbar_params)
Expand Down Expand Up @@ -366,7 +375,7 @@ def _render_shapes(
value_to_plot=col_for_color,
groups=groups,
palette=render_params.palette,
na_color=render_params.color if render_params.color is not None else render_params.cmap_params.na_color,
na_color=(render_params.color if render_params.color is not None else render_params.cmap_params.na_color),
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
Expand Down Expand Up @@ -440,7 +449,10 @@ def _render_shapes(
if not (render_params.shape == "circle" and (current_type == "Point").all()):
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
max_extent = np.max(
[shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]]
[
shapes.total_bounds[2] - shapes.total_bounds[0],
shapes.total_bounds[3] - shapes.total_bounds[1],
]
)
shapes = _convert_shapes(shapes, render_params.shape, max_extent)

Expand Down Expand Up @@ -565,7 +577,15 @@ def _render_shapes(
na_color_hex,
)

_render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext)
_render_ds_outlines(
cvs,
transformed_element,
render_params,
fig_params,
ax,
factor,
x_ext + y_ext,
)

_cax = _render_ds_image(
ax,
Expand Down Expand Up @@ -832,7 +852,13 @@ def _render_points(
)

if added_color_from_table and col_for_color is not None:
_reparse_points(sdata_filt, element, points_pd_with_color, transformation_in_cs, coordinate_system)
_reparse_points(
sdata_filt,
element,
points_pd_with_color,
transformation_in_cs,
coordinate_system,
)

_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)

Expand Down Expand Up @@ -1094,6 +1120,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
return False, False


def _collect_channel_legend_entries(
channels: Sequence[str | int],
seed_colors: Sequence[str | tuple[float, ...]],
channel_legend_entries: list[ChannelLegendEntry],
) -> None:
"""Accumulate channel-to-color mappings for a deferred combined legend."""
channel_names = [str(ch) for ch in channels]
if len(set(channel_names)) != len(channel_names):
logger.warning("channels_as_legend: duplicate channel names detected; skipping legend entries.")
return

color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors]
for name, color in zip(channel_names, color_hexes, strict=True):
channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color))


def _draw_channel_legend(
ax: matplotlib.axes.SubplotBase,
entries: list[ChannelLegendEntry],
legend_params: LegendParams,
fig_params: FigParams,
) -> None:
"""Draw a single combined categorical legend from accumulated channel entries.

Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
calling it here automatically merges with any earlier legend entries
(e.g. from labels or shapes) on the same axes via ``ax.legend()``.

``multi_panel`` is only set when no prior legend exists on the axis,
to avoid shrinking the axes twice (once for labels/shapes, once for
channels).
"""
# Deduplicate: if the same channel name appears twice, keep the last color
palette_dict: dict[str, str] = {}
for entry in entries:
palette_dict[entry.channel_name] = entry.color_hex

legend_loc = legend_params.legend_loc
if legend_loc == "on data":
logger.warning(
"legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
"falling back to 'right margin'."
)
legend_loc = "right margin"

categories = pd.Categorical(list(palette_dict))

path_effect = (
[patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")]
if legend_params.legend_fontoutline is not None
else []
)

# Only apply multi_panel shrink if no legend already exists on this axis
# (labels/shapes draw their legend during the render loop and already shrink).
has_existing_legend = ax.get_legend() is not None
needs_multi_panel = fig_params.axs is not None and not has_existing_legend

_add_categorical_legend(
ax,
categories,
palette=palette_dict,
legend_loc=legend_loc,
legend_fontweight=legend_params.legend_fontweight,
legend_fontsize=legend_params.legend_fontsize,
legend_fontoutline=path_effect,
na_color=["lightgray"],
na_in_legend=False,
multi_panel=needs_multi_panel,
)


def _render_images(
sdata: sd.SpatialData,
render_params: ImageRenderParams,
Expand All @@ -1104,6 +1202,7 @@ def _render_images(
legend_params: LegendParams,
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
channel_legend_entries: list[ChannelLegendEntry] | None = None,
) -> None:
_log_context.set("render_images")
sdata_filt = sdata.filter_by_coordinate_system(
Expand Down Expand Up @@ -1325,10 +1424,14 @@ def _render_images(

layers[ch] = ch_norm(layers[ch])

# Colors for the channel legend (set by each branch if applicable)
legend_colors: list[str] | None = None

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.cmap_is_default: # -> use RGB
stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1)
legend_colors = ["red", "green", "blue"]
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
Expand Down Expand Up @@ -1410,6 +1513,8 @@ def _render_images(
f"multichannel strategy 'stack' to render."
) # TODO: update when pca is added as strategy

legend_colors = seed_colors

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1427,6 +1532,8 @@ def _render_images(
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = np.clip(colored[:, :, :3], 0, 1)

legend_colors = list(palette)

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1446,6 +1553,8 @@ def _render_images(
)
colored = colored[:, :, :3]

legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps]

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1458,6 +1567,17 @@ def _render_images(
elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")

# Collect channel legend entries (single point for all multi-channel paths)
if render_params.channels_as_legend and channel_legend_entries is not None:
if legend_colors is not None:
_collect_channel_legend_entries(channels, legend_colors, channel_legend_entries)
else:
logger.warning(
"channels_as_legend requires distinct per-channel colors; "
"ignored when a single cmap is shared across channels. "
"Use 'palette' or a list of cmaps instead."
)


def _render_labels(
sdata: sd.SpatialData,
Expand Down
13 changes: 12 additions & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class Color:
user_defined_alpha: bool = False

def __init__(
self, color: None | str | list[float] | tuple[float, ...] = "default", alpha: float | int | None = None
self,
color: None | str | list[float] | tuple[float, ...] = "default",
alpha: float | int | None = None,
) -> None:
# 1) Validate alpha value
if alpha is None:
Expand Down Expand Up @@ -199,6 +201,14 @@ class ColorbarSpec:
alpha: float | None = None


@dataclass
class ChannelLegendEntry:
"""A single channel-to-color mapping for the categorical channel legend."""

channel_name: str
color_hex: str


CBAR_DEFAULT_LOCATION = "right"
CBAR_DEFAULT_FRACTION = 0.075
CBAR_DEFAULT_PAD = 0.015
Expand Down Expand Up @@ -274,6 +284,7 @@ class ImageRenderParams:
colorbar_params: dict[str, object] | None = None
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
grayscale: bool = False
channels_as_legend: bool = False


@dataclass
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading