diff --git a/src/spatialdata_plot/pl/__init__.py b/src/spatialdata_plot/pl/__init__.py index 8bf47aa9..178b4914 100644 --- a/src/spatialdata_plot/pl/__init__.py +++ b/src/spatialdata_plot/pl/__init__.py @@ -1,5 +1,8 @@ +from ._palette import make_palette, make_palette_from_data from .basic import PlotAccessor __all__ = [ "PlotAccessor", + "make_palette", + "make_palette_from_data", ] diff --git a/src/spatialdata_plot/pl/_palette.py b/src/spatialdata_plot/pl/_palette.py new file mode 100644 index 00000000..9f700181 --- /dev/null +++ b/src/spatialdata_plot/pl/_palette.py @@ -0,0 +1,697 @@ +"""Palette generation utilities. + +Two public functions: + +- :func:`make_palette` — produce *n* colours, optionally reordered for + maximum perceptual contrast or colourblind accessibility. +- :func:`make_palette_from_data` — like :func:`make_palette` but derives + the number of colours and (for ``spaco`` methods) the assignment order + from a :class:`~spatialdata.SpatialData` element. + +Both share the same *palette* / *method* vocabulary. The *palette* +parameter controls **which** colours are used (the source), while +*method* controls **how** they are ordered or assigned. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np +import pandas as pd +from matplotlib.colors import ListedColormap, to_hex, to_rgb +from matplotlib.pyplot import colormaps as mpl_colormaps +from scanpy.plotting.palettes import default_20, default_28, default_102 +from scipy.spatial import cKDTree + +from spatialdata_plot._logging import logger + +if TYPE_CHECKING: + from collections.abc import Sequence + + import spatialdata as sd + +# --------------------------------------------------------------------------- +# Built-in named palettes +# --------------------------------------------------------------------------- + +# Okabe & Ito (2008) — designed for universal colour-vision accessibility. +# Hex values from https://jfly.uni-koeln.de/color/ +_OKABE_ITO: list[str] = [ + "#E69F00", # orange + "#56B4E9", # sky blue + "#009E73", # bluish green + "#F0E442", # yellow + "#0072B2", # blue + "#D55E00", # vermillion + "#CC79A7", # reddish purple + "#000000", # black +] + +_NAMED_PALETTES: dict[str, list[str]] = { + "okabe_ito": _OKABE_ITO, +} + +# --------------------------------------------------------------------------- +# Color-space helpers +# --------------------------------------------------------------------------- + +# Oklab conversion (Björn Ottosson, public domain) +# https://bottosson.github.io/posts/oklab/ + + +def _srgb_to_linear(c: np.ndarray) -> np.ndarray: + """SRGB [0,1] → linear RGB.""" + return np.where(c <= 0.04045, c / 12.92, ((c + 0.055) / 1.055) ** 2.4) + + +def _linear_to_srgb(c: np.ndarray) -> np.ndarray: + """Linear RGB → sRGB [0,1].""" + return np.where(c <= 0.0031308, 12.92 * c, 1.055 * c ** (1.0 / 2.4) - 0.055) + + +def _rgb_to_oklab(rgb: np.ndarray) -> np.ndarray: + """Convert Nx3 sRGB [0,1] array to Oklab.""" + lin = _srgb_to_linear(rgb) + l = 0.4122214708 * lin[:, 0] + 0.5363325363 * lin[:, 1] + 0.0514459929 * lin[:, 2] + m = 0.2119034982 * lin[:, 0] + 0.6806995451 * lin[:, 1] + 0.1073969566 * lin[:, 2] + s = 0.0883024619 * lin[:, 0] + 0.2817188376 * lin[:, 1] + 0.6299787005 * lin[:, 2] + l_ = np.cbrt(l) + m_ = np.cbrt(m) + s_ = np.cbrt(s) + return np.column_stack( + [ + 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, + 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, + 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, + ] + ) + + +# --------------------------------------------------------------------------- +# Color-vision-deficiency simulation (Brettel, Viénot & Mollon 1997) +# --------------------------------------------------------------------------- + +# Simulation matrices for dichromacy in linear RGB space. +# Source: libDaltonLens / DaltonLens-Python (MIT licensed constants). +_CVD_MATRICES: dict[str, np.ndarray] = { + "protanopia": np.array( + [[0.152286, 1.052583, -0.204868], [0.114503, 0.786281, 0.099216], [-0.003882, -0.048116, 1.051998]] + ), + "deuteranopia": np.array( + [[0.367322, 0.860646, -0.227968], [0.280085, 0.672501, 0.047413], [-0.011820, 0.042940, 0.968881]] + ), + "tritanopia": np.array( + [[-0.006540, 0.975530, 0.031010], [0.016270, 0.943972, 0.039758], [-0.244708, 0.759930, 0.484778]] + ), +} + + +def _simulate_cvd(rgb: np.ndarray, cvd_type: str) -> np.ndarray: + """Simulate color vision deficiency on Nx3 sRGB [0,1] array. + + For ``"general"``, returns the element-wise minimum distinctness across + all three deficiency types (worst-case). + """ + if cvd_type == "general": + return np.stack([_simulate_cvd(rgb, t) for t in ("protanopia", "deuteranopia", "tritanopia")]) + + mat = _CVD_MATRICES[cvd_type] + lin = _srgb_to_linear(rgb) + sim = lin @ mat.T + return np.clip(_linear_to_srgb(np.clip(sim, 0, 1)), 0, 1) # type: ignore[no-any-return] + + +# --------------------------------------------------------------------------- +# Shared optimization core +# --------------------------------------------------------------------------- + + +def _perceptual_distance_matrix( + rgb: np.ndarray, + colorblind_type: str | None = None, +) -> np.ndarray: + """Pairwise Oklab Euclidean distance between colors. + + If *colorblind_type* is set, distances are computed on CVD-simulated + colors. For ``"general"``, the minimum distance across all three + deficiency types is used (worst-case optimization). + """ + if colorblind_type is not None: + sim = _simulate_cvd(rgb, colorblind_type) + if colorblind_type == "general": + mats = [_pairwise_oklab_dist(_rgb_to_oklab(s)) for s in sim] + return np.minimum.reduce(mats) # type: ignore[no-any-return] + rgb = sim + + lab = _rgb_to_oklab(rgb) + return _pairwise_oklab_dist(lab) + + +def _pairwise_oklab_dist(lab: np.ndarray) -> np.ndarray: + """Pairwise Euclidean distance in Oklab space.""" + diff = lab[:, np.newaxis, :] - lab[np.newaxis, :, :] + return np.sqrt((diff**2).sum(axis=-1)) # type: ignore[no-any-return] + + +def _optimize_assignment( + weight_matrix: np.ndarray, + color_dist: np.ndarray, + n_random: int = 5000, + n_swaps: int = 10000, + rng: np.random.Generator | None = None, +) -> np.ndarray: + """Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``. + + Works for both spatial interlacement weights (spaco) and uniform + weights (pure contrast maximization). + + Returns an index array: ``perm[category_idx] = color_idx``. + """ + if rng is None: + rng = np.random.default_rng() + + n = weight_matrix.shape[0] + if n <= 2: + # For n<=2 there are at most 2 permutations; just try both. + if n <= 1: + return np.arange(n) + id_perm = np.arange(n) + sw_perm = np.array([1, 0]) + s_id = float(np.sum(weight_matrix * color_dist[np.ix_(id_perm, id_perm)])) + s_sw = float(np.sum(weight_matrix * color_dist[np.ix_(sw_perm, sw_perm)])) + return sw_perm if s_sw > s_id else id_perm + + def _score(perm: np.ndarray) -> float: + return float(np.sum(weight_matrix * color_dist[np.ix_(perm, perm)])) + + best_perm = np.arange(n) + best_score = _score(best_perm) + + for _ in range(n_random): + perm = rng.permutation(n) + s = _score(perm) + if s > best_score: + best_score = s + best_perm = perm.copy() + + for _ in range(n_swaps): + i, j = rng.integers(0, n, size=2) + if i == j: + continue + best_perm[i], best_perm[j] = best_perm[j], best_perm[i] + s = _score(best_perm) + if s > best_score: + best_score = s + else: + best_perm[i], best_perm[j] = best_perm[j], best_perm[i] + + return best_perm + + +def _optimized_order( + colors_list: list[str], + *, + colorblind_type: str | None = None, + n_random: int = 5000, + n_swaps: int = 10000, + seed: int = 0, +) -> list[str]: + """Reorder *colors_list* to maximize pairwise perceptual spread.""" + n = len(colors_list) + if n <= 2: + return colors_list + + rgb = np.array([to_rgb(c) for c in colors_list]) + cdist = _perceptual_distance_matrix(rgb, colorblind_type=colorblind_type) + + # Uniform weight matrix: all off-diagonal pairs equally important + weights = np.ones((n, n)) - np.eye(n) + + rng = np.random.default_rng(seed) + perm = _optimize_assignment(weights, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng) + return [to_hex(rgb[perm[i]]) for i in range(n)] + + +# --------------------------------------------------------------------------- +# Spatial interlacement (spaco-specific) +# --------------------------------------------------------------------------- + + +def _spatial_interlacement( + coords: np.ndarray, + labels: np.ndarray, + categories: Sequence[str], + n_neighbors: int = 15, +) -> np.ndarray: + """Build a symmetric interlacement matrix (n_categories × n_categories). + + Entry (i, j) reflects how much categories i and j are spatially + interleaved, measured by inverse-distance-weighted neighbor counts. + """ + n_cat = len(categories) + cat_to_idx = {c: i for i, c in enumerate(categories)} + label_idx = np.array([cat_to_idx[l] for l in labels]) + + tree = cKDTree(coords) + dists, indices = tree.query(coords, k=min(n_neighbors + 1, len(coords))) + + # Vectorized accumulation (avoids Python double-loop over cells × neighbors) + neighbor_dists = dists[:, 1:] + neighbor_indices = indices[:, 1:] + cell_cats = label_idx + neighbor_cats = label_idx[neighbor_indices] + + # Mask: different category and positive distance + cross_cat = neighbor_cats != cell_cats[:, np.newaxis] + valid_dist = neighbor_dists > 0 + mask = cross_cat & valid_dist + + weights = np.where(mask, 1.0 / np.where(neighbor_dists > 0, neighbor_dists, 1.0), 0.0) + + rows = np.broadcast_to(cell_cats[:, np.newaxis], neighbor_cats.shape)[mask] + cols = neighbor_cats[mask] + vals = weights[mask] + + mat = np.zeros((n_cat, n_cat), dtype=np.float64) + np.add.at(mat, (rows, cols), vals) + + mat = np.maximum(mat, mat.T) + max_val = mat.max() + if max_val > 0: + mat /= max_val + return mat # type: ignore[no-any-return] + + +# --------------------------------------------------------------------------- +# Palette resolution +# --------------------------------------------------------------------------- + + +def _resolve_palette(palette: list[str] | str | None, n: int) -> list[str]: + """Resolve *n* colours from an explicit list, a named palette, or scanpy defaults.""" + if isinstance(palette, list): + if len(palette) < n: + raise ValueError(f"Palette has {len(palette)} colors but {n} are needed.") + return list(palette[:n]) + + if isinstance(palette, str): + if palette in _NAMED_PALETTES: + colors = _NAMED_PALETTES[palette] + if len(colors) < n: + raise ValueError( + f"Named palette '{palette}' has {len(colors)} colors but {n} are needed. " + f"Please provide a palette with at least {n} colors." + ) + return list(colors[:n]) + + if palette in mpl_colormaps: + cmap = mpl_colormaps[palette] + if isinstance(cmap, ListedColormap): + # Qualitative colormaps (tab10, Set1, etc.): sample by index + if n > cmap.N: + raise ValueError(f"Colormap '{palette}' has {cmap.N} colors but {n} are needed.") + return [to_hex(cmap(i)) for i in range(n)] + indices = np.linspace(0, 1, n) + return [to_hex(cmap(i)) for i in indices] + + raise ValueError( + f"Unknown palette name '{palette}'. Use a list of colors, a matplotlib colormap name, " + f"or one of: {', '.join(sorted(_NAMED_PALETTES))}." + ) + + # palette is None — use scanpy defaults + if n <= 20: + return list(default_20[:n]) + if n <= 28: + return list(default_28[:n]) + if n <= len(default_102): + return list(default_102[:n]) + + raise ValueError( + f"{n} colors needed but no palette was provided and the default palette only has " + f"{len(default_102)} colors. Please provide a palette." + ) + + +def _resolve_element( + sdata: sd.SpatialData, + element: str, + color: str, + table_name: str | None = None, +) -> tuple[np.ndarray, pd.Categorical]: + """Extract coordinates and categorical labels from a SpatialData element. + + Coordinates come from the element geometry (shapes) or x/y columns + (points). Labels come from a column on the element itself, or from + a linked table (joined on the instance key to guarantee alignment). + """ + if element in sdata.shapes: + gdf = sdata.shapes[element] + coords = np.column_stack([gdf.geometry.centroid.x, gdf.geometry.centroid.y]) + if color in gdf.columns: + labels_series = gdf[color] + else: + labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name) + # Align coords to table rows via matched instance indices + coords = coords[matched_indices] + elif element in sdata.points: + ddf = sdata.points[element] + if "x" not in ddf.columns or "y" not in ddf.columns: + raise ValueError(f"Points element '{element}' does not have 'x' and 'y' columns.") + if color in ddf.columns: + df = ddf[["x", "y", color]].compute() + coords = df[["x", "y"]].values.astype(np.float64) + labels_series = df[color] + else: + df = ddf[["x", "y"]].compute() + coords = df[["x", "y"]].values.astype(np.float64) + labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name) + coords = coords[matched_indices] + else: + available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + raise KeyError( + f"Element '{element}' not found in sdata.shapes or sdata.points. " + f"Available elements: {available}. Note: labels (raster) elements are not yet supported." + ) + + is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype) + labels_cat = labels_series.values if is_categorical else pd.Categorical(labels_series) + return coords, labels_cat + + +def _get_labels_from_table( + sdata: sd.SpatialData, + element: str, + color: str, + table_name: str | None = None, +) -> tuple[pd.Series, np.ndarray]: + """Extract a column from the table linked to an element. + + Returns (labels_series, element_indices) where element_indices maps + each table row to its position in the element, ensuring coord-label + alignment. + """ + from spatialdata.models import get_table_keys + + matches: list[str] = [] + for name in sdata.tables: + table = sdata.tables[name] + region = table.uns.get("spatialdata_attrs", {}).get("region") + if region is not None: + regions = [region] if isinstance(region, str) else region + if element in regions and color in table.obs.columns: + matches.append(name) + + if not matches: + raise KeyError( + f"Column '{color}' not found for element '{element}'. Looked in the element itself and all linked tables." + ) + + if table_name is not None: + if table_name not in matches: + raise KeyError( + f"Table '{table_name}' does not annotate element '{element}' or does not contain column '{color}'." + ) + resolved_name = table_name + elif len(matches) == 1: + resolved_name = matches[0] + else: + raise ValueError( + f"Multiple tables annotate element '{element}' with column '{color}': {matches}. " + f"Please specify table_name= to disambiguate." + ) + + table = sdata.tables[resolved_name] + _, _, instance_key = get_table_keys(table) + + # Join on instance key to align table rows with element positions + instance_ids = table.obs[instance_key].values + element_index = sdata.shapes[element].index if element in sdata.shapes else sdata.points[element].compute().index + + # Map each table instance_id to its position in the element index + element_idx_map = {val: i for i, val in enumerate(element_index)} + matched_indices = [] + valid_mask = [] + for iid in instance_ids: + if iid in element_idx_map: + matched_indices.append(element_idx_map[iid]) + valid_mask.append(True) + else: + valid_mask.append(False) + + valid_mask_arr = np.array(valid_mask) + if not any(valid_mask): + raise ValueError(f"No matching instance keys between table '{resolved_name}' and element '{element}'.") + + labels = table.obs.loc[valid_mask_arr, color] + return labels.reset_index(drop=True), np.array(matched_indices) + + +# --------------------------------------------------------------------------- +# Method lookup tables +# --------------------------------------------------------------------------- + +# Maps non-spatial contrast methods → CVD type (None = normal vision). +_CONTRAST_CVD_TYPES: dict[str, str | None] = { + "contrast": None, + "colorblind": "general", + "protanopia": "protanopia", + "deuteranopia": "deuteranopia", + "tritanopia": "tritanopia", +} + +# Maps spaco methods → CVD type (None = normal vision). +_SPACO_CVD_TYPES: dict[str, str | None] = { + "spaco": None, + "spaco_colorblind": "general", + "spaco_protanopia": "protanopia", + "spaco_deuteranopia": "deuteranopia", + "spaco_tritanopia": "tritanopia", +} + +_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES, *_SPACO_CVD_TYPES}) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +Method = Literal[ + "default", + "contrast", + "colorblind", + "protanopia", + "deuteranopia", + "tritanopia", + "spaco", + "spaco_colorblind", + "spaco_protanopia", + "spaco_deuteranopia", + "spaco_tritanopia", +] + + +def make_palette( + n: int, + *, + palette: list[str] | str | None = None, + method: Method = "default", + n_random: int = 5000, + n_swaps: int = 10000, + seed: int = 0, +) -> list[str]: + """Generate a list of *n* colours. + + The *palette* parameter controls **which** colours are sampled, while + *method* controls **how** they are ordered. + + Parameters + ---------- + n + Number of colours to produce. + palette + Source colours. Can be: + + - ``None`` — scanpy default palettes. + - A **list** of colour strings (hex or named). + - A **named palette**: ``"okabe_ito"`` (8 colourblind-safe + colours). + - A **matplotlib colormap name**: ``"tab10"``, ``"Set2"``, etc. + method + Ordering strategy: + + - ``"default"`` — take the first *n* colours in source order. + - ``"contrast"`` — reorder to maximise pairwise perceptual + distance (Oklab space). + - ``"colorblind"`` — reorder to maximise pairwise distance + under worst-case colour-vision deficiency. + - ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` — + reorder for a specific colour-vision deficiency. + + The ``spaco*`` methods require spatial data and are only + available via :func:`make_palette_from_data`. + n_random + Random permutations to try (optimisation methods only). + n_swaps + Pairwise swap iterations (optimisation methods only). + seed + Random seed for reproducibility (optimisation methods only). + + Returns + ------- + list[str] + List of *n* hex colour strings. + + Examples + -------- + >>> sdp.pl.make_palette(5) + >>> sdp.pl.make_palette(8, palette="okabe_ito") + >>> sdp.pl.make_palette(10, palette="tab10", method="contrast") + >>> sdp.pl.make_palette(6, palette="tab10", method="colorblind") + """ + if n < 1: + raise ValueError(f"n must be at least 1, got {n}.") + + if method in _SPACO_CVD_TYPES: + raise ValueError(f"Method '{method}' requires spatial data. Use make_palette_from_data() instead.") + + colors = _resolve_palette(palette, n) + + if method == "default": + return [to_hex(to_rgb(c)) for c in colors] + + if method in _CONTRAST_CVD_TYPES: + cvd_type = _CONTRAST_CVD_TYPES[method] + return _optimized_order(colors, colorblind_type=cvd_type, n_random=n_random, n_swaps=n_swaps, seed=seed) + + valid = ", ".join(f"'{m}'" for m in _ALL_METHODS) + raise ValueError(f"Unknown method '{method}'. Choose from {valid}.") + + +def make_palette_from_data( + sdata: sd.SpatialData, + element: str, + color: str, + *, + palette: list[str] | str | None = None, + method: Method = "default", + table_name: str | None = None, + n_neighbors: int = 15, + n_random: int = 5000, + n_swaps: int = 10000, + seed: int = 0, +) -> dict[str, str]: + """Generate a categorical colour palette for a spatial element. + + The *palette* parameter controls **which** colours are used (the source), + while *method* controls **how** they are assigned to categories. + + Parameters + ---------- + sdata + A :class:`spatialdata.SpatialData` object. + element + Name of a shapes or points element in *sdata*. + color + Column name containing categorical labels (in the element itself + for points, or in the linked table for shapes/labels). + palette + Source colours. Accepts the same values as + :func:`make_palette` (*None*, a list, a named palette, or a + matplotlib colormap name). + table_name + Name of the table to use when *color* is looked up from a linked + table. Required when multiple tables annotate the same element. + method + Strategy for assigning colours to categories. Accepts all + methods from :func:`make_palette` plus spatially-aware ones: + + - ``"default"`` — assign in sorted category order (reproduces + the current render-pipeline behaviour). + - ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` / + ``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise + perceptual spread (ignores spatial layout). + - ``"spaco"`` — spatially-aware assignment (Jing et al., + *Patterns* 2023). Maximises perceptual contrast between + categories that are spatially interleaved. + - ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under + worst-case colour-vision deficiency (all three types). + - ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` / + ``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for + a specific colour-vision deficiency. + n_neighbors + Only used with ``spaco`` methods. Number of spatial neighbours + for the interlacement computation. + n_random + Random permutations to try (optimisation methods only). + n_swaps + Pairwise swap iterations (optimisation methods only). + seed + Random seed for reproducibility (optimisation methods only). + + Returns + ------- + dict[str, str] + Mapping from category name to hex colour string. Can be passed + directly as ``palette=`` to any render function. + + Examples + -------- + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type") + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10") + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco") + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind") + >>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show() + """ + coords, labels_cat = _resolve_element(sdata, element, color, table_name=table_name) + + categories = list(labels_cat.categories) + n_cat = len(categories) + if n_cat == 0: + raise ValueError(f"No categories found in column '{color}'.") + + colors_list = _resolve_palette(palette, n_cat) + + if method == "default": + return {cat: to_hex(to_rgb(c)) for cat, c in zip(categories, colors_list, strict=True)} + + # Non-spatial contrast methods (same as make_palette but returns dict) + if method in _CONTRAST_CVD_TYPES: + cvd_type = _CONTRAST_CVD_TYPES[method] + reordered = _optimized_order( + colors_list, colorblind_type=cvd_type, n_random=n_random, n_swaps=n_swaps, seed=seed + ) + return dict(zip(categories, reordered, strict=True)) + + # Spaco methods (spatially-aware) + if method in _SPACO_CVD_TYPES: + cvd_type = _SPACO_CVD_TYPES[method] + + # Filter NaN labels + mask = labels_cat.codes != -1 + coords_clean = coords[mask] + labels_clean = np.array(categories)[labels_cat.codes[mask]] + + if len(coords_clean) == 0: + raise ValueError(f"All values in column '{color}' are NaN.") + + rgb = np.array([to_rgb(c) for c in colors_list]) + + if n_cat == 1: + return {categories[0]: to_hex(rgb[0])} + + logger.info(f"Computing spatial interlacement for {n_cat} categories ({len(coords_clean)} cells)...") + inter = _spatial_interlacement(coords_clean, labels_clean, categories, n_neighbors=n_neighbors) + + logger.info("Computing perceptual distance matrix...") + cdist = _perceptual_distance_matrix(rgb, colorblind_type=cvd_type) + + logger.info("Optimizing color assignment...") + rng = np.random.default_rng(seed) + perm = _optimize_assignment(inter, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng) + + return {cat: to_hex(rgb[perm[i]]) for i, cat in enumerate(categories)} + + valid = ", ".join(f"'{m}'" for m in _ALL_METHODS) + raise ValueError(f"Unknown method '{method}'. Choose from {valid}.") diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2a558361..0a99f395 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -172,7 +172,7 @@ def render_shapes( *, fill_alpha: float | int | None = None, groups: list[str] | str | None = None, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, na_color: ColorLike | None = "default", outline_width: float | int | tuple[float | int, float | int] | None = None, outline_color: ColorLike | tuple[ColorLike] | None = None, @@ -369,7 +369,7 @@ def render_points( *, alpha: float | int | None = None, groups: list[str] | str | None = None, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, na_color: ColorLike | None = "default", cmap: Colormap | str | None = None, norm: Normalize | None = None, @@ -707,7 +707,7 @@ def render_labels( *, groups: list[str] | str | None = None, contour_px: int | None = 3, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, cmap: Colormap | str | None = None, norm: Normalize | None = None, na_color: ColorLike | None = "default", diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 344df5f9..dffb97dd 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -223,7 +223,7 @@ class ShapesRenderParams: col_for_color: str | None = None groups: str | list[str] | None = None contour_px: int | None = None - palette: ListedColormap | list[str] | None = None + palette: ListedColormap | dict[str, str] | list[str] | None = None outline_alpha: tuple[float, float] = (1.0, 1.0) fill_alpha: float = 0.3 scale: float = 1.0 @@ -247,7 +247,7 @@ class PointsRenderParams: color: Color | None = None col_for_color: str | None = None groups: str | list[str] | None = None - palette: ListedColormap | list[str] | None = None + palette: ListedColormap | dict[str, str] | list[str] | None = None alpha: float = 1.0 size: float = 1.0 transfunc: Callable[[float], float] | None = None @@ -288,7 +288,7 @@ class LabelsRenderParams: groups: str | list[str] | None = None contour_px: int | None = None outline: bool = False - palette: ListedColormap | list[str] | None = None + palette: ListedColormap | dict[str, str] | list[str] | None = None outline_alpha: float = 1.0 outline_color: Color | None = None fill_alpha: float = 0.4 diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 8530aec1..7746c8af 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1024,7 +1024,7 @@ def _set_color_source_vec( na_color: Color, element_name: list[str] | str | None = None, groups: list[str] | str | None = None, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, cmap_params: CmapParams | None = None, alpha: float = 1.0, table_name: str | None = None, @@ -1519,7 +1519,7 @@ def _to_hex_no_alpha(color_value: Any) -> str | None: def _modify_categorical_color_mapping( mapping: Mapping[str, str], groups: list[str] | str | None = None, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, ) -> Mapping[str, str]: if groups is None or isinstance(groups, list) and groups[0] is None: return mapping @@ -1577,12 +1577,24 @@ def _get_categorical_color_mapping( cmap_params: CmapParams | None = None, alpha: float = 1, groups: list[str] | str | None = None, - palette: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, render_type: Literal["points", "labels"] | None = None, ) -> Mapping[str, str]: if not isinstance(color_source_vector, Categorical): raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") + # Dict palette (e.g. from make_palette_from_data): use directly as category→color mapping + if isinstance(palette, dict): + na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color) + if isinstance(groups, str): + groups = [groups] + if groups is not None: + mapping = {cat: palette.get(cat, na_color_hex) for cat in groups if cat in color_source_vector.categories} + else: + mapping = {cat: palette.get(cat, na_color_hex) for cat in color_source_vector.categories} + mapping["NaN"] = na_color_hex + return mapping + if isinstance(groups, str): groups = [groups] @@ -2395,14 +2407,21 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st palette = param_dict["palette"] - if isinstance(palette, list): + # dict palettes (e.g. from make_palette_from_data) bypass groups validation + if isinstance(palette, dict): + from matplotlib.colors import is_color_like + + invalid = [f"'{k}': '{v}'" for k, v in palette.items() if not is_color_like(v)] + if invalid: + raise ValueError(f"Dict palette contains invalid color values: {', '.join(invalid)}.") + elif isinstance(palette, list): if not all(isinstance(p, str) for p in palette): raise ValueError("If specified, parameter 'palette' must contain only strings.") elif isinstance(palette, str | type(None)) and "palette" in param_dict: param_dict["palette"] = [palette] if palette is not None else None palette_group = param_dict.get("palette") - if element_type in ["shapes", "points", "labels"] and palette_group is not None: + if element_type in ["shapes", "points", "labels"] and palette_group is not None and not isinstance(palette, dict): groups = param_dict.get("groups") if groups is None: raise ValueError("When specifying 'palette', 'groups' must also be specified.") @@ -2542,7 +2561,7 @@ def _validate_label_render_params( fill_alpha: float | int | None, contour_px: int | None, groups: list[str] | str | None, - palette: list[str] | str | None, + palette: dict[str, str] | list[str] | str | None, na_color: ColorLike | None, norm: Normalize | None, outline_alpha: float | int, @@ -2614,7 +2633,7 @@ def _validate_points_render_params( alpha: float | int | None, color: ColorLike | None, groups: list[str] | str | None, - palette: list[str] | str | None, + palette: dict[str, str] | list[str] | str | None, na_color: ColorLike | None, cmap: list[Colormap | str] | Colormap | str | None, norm: Normalize | None, @@ -2682,7 +2701,7 @@ def _validate_shape_render_params( element: str | None, fill_alpha: float | int | None, groups: list[str] | str | None, - palette: list[str] | str | None, + palette: dict[str, str] | list[str] | str | None, color: ColorLike | None, na_color: ColorLike | None, outline_width: float | int | tuple[float | int, float | int] | None, diff --git a/tests/_images/PaletteVisual_dict_palette_hex_labels.png b/tests/_images/PaletteVisual_dict_palette_hex_labels.png new file mode 100644 index 00000000..2273f591 Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_hex_labels.png differ diff --git a/tests/_images/PaletteVisual_dict_palette_hex_points.png b/tests/_images/PaletteVisual_dict_palette_hex_points.png new file mode 100644 index 00000000..3d844b33 Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_hex_points.png differ diff --git a/tests/_images/PaletteVisual_dict_palette_hex_shapes.png b/tests/_images/PaletteVisual_dict_palette_hex_shapes.png new file mode 100644 index 00000000..0d251b9d Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_hex_shapes.png differ diff --git a/tests/_images/PaletteVisual_dict_palette_named_colors_labels.png b/tests/_images/PaletteVisual_dict_palette_named_colors_labels.png new file mode 100644 index 00000000..2273f591 Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_named_colors_labels.png differ diff --git a/tests/_images/PaletteVisual_dict_palette_named_colors_points.png b/tests/_images/PaletteVisual_dict_palette_named_colors_points.png new file mode 100644 index 00000000..0477e133 Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_named_colors_points.png differ diff --git a/tests/_images/PaletteVisual_dict_palette_named_colors_shapes.png b/tests/_images/PaletteVisual_dict_palette_named_colors_shapes.png new file mode 100644 index 00000000..cb9d9f3a Binary files /dev/null and b/tests/_images/PaletteVisual_dict_palette_named_colors_shapes.png differ diff --git a/tests/pl/test_palette.py b/tests/pl/test_palette.py new file mode 100644 index 00000000..37affabb --- /dev/null +++ b/tests/pl/test_palette.py @@ -0,0 +1,362 @@ +"""Tests for palette generation (issue #210).""" + +from __future__ import annotations + +import matplotlib +import numpy as np +import pandas as pd +import pytest +import scanpy as sc +from matplotlib.colors import to_hex, to_rgb +from spatialdata import SpatialData +from spatialdata.models import PointsModel, ShapesModel, TableModel + +import spatialdata_plot # noqa: F401 — registers accessor +from spatialdata_plot.pl._palette import ( + _optimize_assignment, + _pairwise_oklab_dist, + _perceptual_distance_matrix, + _rgb_to_oklab, + _simulate_cvd, + _spatial_interlacement, + make_palette, + make_palette_from_data, +) +from tests.conftest import DPI, PlotTester, PlotTesterMeta + +matplotlib.use("agg") + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _build_clustered_points_sdata(seed: int = 0) -> SpatialData: + """SpatialData with interleaved A/B clusters near origin and isolated C far away.""" + rng = np.random.default_rng(seed) + coords_a = np.array([[0, 0], [1, 0], [0, 1]], dtype=float) + rng.normal(0, 0.05, (3, 2)) + coords_b = np.array([[0.5, 0.5], [1.5, 0.5], [0.5, 1.5]], dtype=float) + rng.normal(0, 0.05, (3, 2)) + coords_c = np.array([[10, 10], [11, 10], [10, 11]], dtype=float) + rng.normal(0, 0.05, (3, 2)) + + coords = np.vstack([coords_a, coords_b, coords_c]) + labels = pd.Categorical(["A"] * 3 + ["B"] * 3 + ["C"] * 3) + df = pd.DataFrame({"x": coords[:, 0], "y": coords[:, 1], "cell_type": labels}) + return SpatialData(points={"cells": PointsModel.parse(df)}) + + +def _build_shapes_sdata(seed: int = 0) -> SpatialData: + """SpatialData with shapes + linked table containing categorical labels.""" + from anndata import AnnData + from geopandas import GeoDataFrame + from shapely import Point + + rng = np.random.default_rng(seed) + n = 30 + coords = rng.normal(size=(n, 2)) * 5 + gdf = GeoDataFrame({"radius": np.ones(n)}, geometry=[Point(x, y) for x, y in coords]) + gdf.index = pd.RangeIndex(n) + + adata = AnnData( + np.zeros((n, 1)), + obs=pd.DataFrame( + { + "cell_type": pd.Categorical(rng.choice(["X", "Y", "Z"], size=n)), + "instance_id": np.arange(n), + "region": ["my_shapes"] * n, + }, + index=pd.RangeIndex(n).astype(str), + ), + ) + adata = TableModel.parse(adata=adata, region="my_shapes", region_key="region", instance_key="instance_id") + return SpatialData(shapes={"my_shapes": ShapesModel.parse(gdf)}, tables={"table": adata}) + + +@pytest.fixture(scope="module") +def clustered_sdata() -> SpatialData: + return _build_clustered_points_sdata() + + +@pytest.fixture(scope="module") +def shapes_sdata() -> SpatialData: + return _build_shapes_sdata() + + +# --------------------------------------------------------------------------- +# Unit tests: internals +# --------------------------------------------------------------------------- + + +class TestOklab: + def test_black_and_white(self): + lab = _rgb_to_oklab(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])) + assert lab[0, 0] == pytest.approx(0.0, abs=0.01) + assert lab[1, 0] == pytest.approx(1.0, abs=0.01) + + def test_pairwise_distance_symmetric(self): + d = _pairwise_oklab_dist(_rgb_to_oklab(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float))) + assert d.shape == (3, 3) + np.testing.assert_allclose(d, d.T) + np.testing.assert_allclose(np.diag(d), 0) + + def test_distinct_colors_have_positive_distance(self): + d = _pairwise_oklab_dist(_rgb_to_oklab(np.array([[1, 0, 0], [0, 0, 1]], dtype=float))) + assert d[0, 1] > 0.1 + + +class TestCVD: + @pytest.mark.parametrize("cvd_type", ["protanopia", "deuteranopia", "tritanopia"]) + def test_output_in_range(self, cvd_type: str): + sim = _simulate_cvd(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float), cvd_type) + assert sim.shape == (3, 3) + assert np.all((sim >= 0) & (sim <= 1)) + + def test_general_returns_stacked(self): + sim = _simulate_cvd(np.array([[1, 0, 0], [0, 1, 0]], dtype=float), "general") + assert sim.shape == (3, 2, 3) + + @pytest.mark.parametrize("cvd_type", ["protanopia", "deuteranopia"]) + def test_red_green_less_distinct(self, cvd_type: str): + rgb = np.array([[1, 0, 0], [0, 1, 0]], dtype=float) + assert _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)[0, 1] < _perceptual_distance_matrix(rgb)[0, 1] + + +class TestSpatialInterlacement: + def test_interleaved_higher_than_separated(self): + coords = np.array([[0, 0], [1, 0], [0.5, 0.5], [1.5, 0.5], [10, 10], [11, 10]]) + mat = _spatial_interlacement(coords, np.array(["A", "B", "A", "B", "C", "C"]), ["A", "B", "C"], n_neighbors=3) + assert mat[0, 1] > mat[0, 2] + assert mat[0, 1] > mat[1, 2] + + def test_diagonal_is_zero(self): + mat = _spatial_interlacement(np.array([[0, 0], [1, 0], [0.5, 0.5]]), np.array(["A", "B", "A"]), ["A", "B"], 2) + np.testing.assert_allclose(np.diag(mat), 0) + + def test_symmetric(self): + rng = np.random.default_rng(42) + mat = _spatial_interlacement(rng.normal(size=(50, 2)), np.array(list("ABCDE") * 10), list("ABCDE"), 5) + np.testing.assert_allclose(mat, mat.T) + + +class TestOptimizer: + def test_single_category(self): + assert list(_optimize_assignment(np.zeros((1, 1)), np.zeros((1, 1)))) == [0] + + def test_two_categories(self): + perm = _optimize_assignment(np.array([[0, 1], [1, 0]], dtype=float), np.array([[0, 10], [10, 0]], dtype=float)) + assert set(perm) == {0, 1} + + def test_deterministic_with_seed(self): + inter = np.random.default_rng(0).random((5, 5)) + inter = np.maximum(inter, inter.T) + np.fill_diagonal(inter, 0) + cdist = np.random.default_rng(1).random((5, 5)) + cdist = np.maximum(cdist, cdist.T) + np.fill_diagonal(cdist, 0) + + p1 = _optimize_assignment(inter, cdist, rng=np.random.default_rng(42)) + p2 = _optimize_assignment(inter, cdist, rng=np.random.default_rng(42)) + np.testing.assert_array_equal(p1, p2) + + +# --------------------------------------------------------------------------- +# Tests: make_palette +# --------------------------------------------------------------------------- + + +class TestMakePalette: + def test_default_returns_n_hex_colors(self): + result = make_palette(5) + assert len(result) == 5 + assert isinstance(result, list) + assert all(c.startswith("#") for c in result) + + @pytest.mark.parametrize("palette", ["okabe_ito", "tab10", None]) + def test_palette_sources(self, palette: str | None): + result = make_palette(4, palette=palette) + assert len(result) == 4 + + def test_custom_list(self): + colors = ["#ff0000", "#00ff00", "#0000ff"] + assert make_palette(3, palette=colors) == [to_hex(to_rgb(c)) for c in colors] + + @pytest.mark.parametrize("method", ["contrast", "colorblind", "deuteranopia"]) + def test_optimization_methods_produce_permutation(self, method: str): + colors = ["#ff0000", "#ff1100", "#0000ff", "#00ff00"] + result = make_palette(4, palette=colors, method=method, seed=42) + assert set(result) == {to_hex(to_rgb(c)) for c in colors} + + def test_deterministic(self): + assert make_palette(5, method="contrast", seed=42) == make_palette(5, method="contrast", seed=42) + + def test_n_zero_raises(self): + with pytest.raises(ValueError, match="at least 1"): + make_palette(0) + + def test_too_few_colors_raises(self): + with pytest.raises(ValueError, match="needed"): + make_palette(10, palette=["red", "blue"]) + + @pytest.mark.parametrize("method", ["spaco", "spaco_colorblind"]) + def test_spaco_methods_raise(self, method: str): + with pytest.raises(ValueError, match="requires spatial data"): + make_palette(3, method=method) # type: ignore[arg-type] + + def test_unknown_method_raises(self): + with pytest.raises(ValueError, match="Unknown method"): + make_palette(3, method="invalid") # type: ignore[arg-type] + + def test_unknown_palette_name_raises(self): + with pytest.raises(ValueError, match="Unknown palette name"): + make_palette(3, palette="nonexistent_palette") + + +# --------------------------------------------------------------------------- +# Tests: make_palette_from_data +# --------------------------------------------------------------------------- + + +class TestMakePaletteFromData: + def test_default_returns_dict(self, clustered_sdata: SpatialData): + result = make_palette_from_data(clustered_sdata, "cells", "cell_type") + assert isinstance(result, dict) + assert set(result.keys()) == {"A", "B", "C"} + assert all(v.startswith("#") for v in result.values()) + + def test_default_matches_scanpy_order(self, clustered_sdata: SpatialData): + from scanpy.plotting.palettes import default_20 + + result = make_palette_from_data(clustered_sdata, "cells", "cell_type") + for i, cat in enumerate(sorted(result.keys())): + assert result[cat] == to_hex(to_rgb(default_20[i])) + + def test_custom_palette(self, clustered_sdata: SpatialData): + colors = ["#ff0000", "#00ff00", "#0000ff"] + result = make_palette_from_data(clustered_sdata, "cells", "cell_type", palette=colors) + assert list(result.values()) == [to_hex(to_rgb(c)) for c in colors] + + @pytest.mark.parametrize("palette", ["okabe_ito", "tab10"]) + def test_named_palette_sources(self, clustered_sdata: SpatialData, palette: str): + result = make_palette_from_data(clustered_sdata, "cells", "cell_type", palette=palette) + assert isinstance(result, dict) and len(result) == 3 + + @pytest.mark.parametrize( + "method", + ["contrast", "colorblind", "spaco", "spaco_colorblind", "spaco_deuteranopia"], + ) + def test_all_methods_return_valid_dict(self, clustered_sdata: SpatialData, method: str): + result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method=method, seed=42) + assert isinstance(result, dict) + assert set(result.keys()) == {"A", "B", "C"} + + def test_spaco_deterministic(self, clustered_sdata: SpatialData): + r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42) + r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42) + assert r1 == r2 + + def test_spaco_different_seeds_can_differ(self, clustered_sdata: SpatialData): + r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=0) + r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=999) + assert set(r1.keys()) == set(r2.keys()) + + def test_spaco_custom_palette_is_permutation(self, clustered_sdata: SpatialData): + colors = ["#ff0000", "#00ff00", "#0000ff"] + result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", palette=colors, seed=42) + assert set(result.values()) == {to_hex(to_rgb(c)) for c in colors} + + def test_spaco_single_category(self): + df = pd.DataFrame({"x": [0.0, 1.0], "y": [0.0, 1.0], "ct": pd.Categorical(["A", "A"])}) + sdata = SpatialData(points={"pts": PointsModel.parse(df)}) + result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0) + assert len(result) == 1 and "A" in result + + def test_spaco_nan_labels_filtered(self): + df = pd.DataFrame( + {"x": [0.0, 1.0, 0.0, 10.0], "y": [0.0, 0.0, 1.0, 10.0], "ct": pd.Categorical(["A", "B", "A", None])} + ) + sdata = SpatialData(points={"pts": PointsModel.parse(df)}) + result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0) + assert {"A", "B"} <= set(result.keys()) + + def test_shapes_with_table(self, shapes_sdata: SpatialData): + result = make_palette_from_data(shapes_sdata, "my_shapes", "cell_type", method="spaco", seed=42) + assert isinstance(result, dict) + assert set(result.keys()) == {"X", "Y", "Z"} + + def test_interleaved_get_distinct_colors(self): + sdata = _build_clustered_points_sdata(seed=0) + palette = ["#ff0000", "#ff1100", "#0000ff"] + result = make_palette_from_data(sdata, "cells", "cell_type", method="spaco", palette=palette, seed=0) + # A and B (interleaved) should not both get red-ish colors + assert result["A"] == "#0000ff" or result["B"] == "#0000ff" + + +# --------------------------------------------------------------------------- +# Error cases +# --------------------------------------------------------------------------- + + +class TestMakePaletteFromDataErrors: + def test_too_few_colors(self, clustered_sdata: SpatialData): + with pytest.raises(ValueError, match="needed"): + make_palette_from_data(clustered_sdata, "cells", "cell_type", palette=["red", "blue"]) + + def test_missing_element(self, clustered_sdata: SpatialData): + with pytest.raises(KeyError, match="not found"): + make_palette_from_data(clustered_sdata, "nonexistent", "cell_type") + + def test_missing_column(self, clustered_sdata: SpatialData): + with pytest.raises(KeyError, match="not found"): + make_palette_from_data(clustered_sdata, "cells", "nonexistent_col") + + def test_unknown_method(self, clustered_sdata: SpatialData): + with pytest.raises(ValueError, match="Unknown method"): + make_palette_from_data(clustered_sdata, "cells", "cell_type", method="invalid") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Integration: dict palette through render pipeline +# --------------------------------------------------------------------------- + + +class TestDictPalette: + def test_dict_palette_in_render_points(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points("blobs_points", color="genes", palette={"0": "#ff0000", "1": "#00ff00"}) + + def test_dict_palette_in_render_labels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="region", palette={"blobs_labels": "#ff0000"}) + + +# --------------------------------------------------------------------------- +# Visual tests +# --------------------------------------------------------------------------- + +sc.pl.set_rcParams_defaults() +sc.set_figure_params(dpi=DPI, color_map="viridis") + + +class TestPaletteVisual(PlotTester, metaclass=PlotTesterMeta): + def test_plot_dict_palette_hex_points(self, sdata_blobs: SpatialData): + palette = make_palette_from_data(sdata_blobs, "blobs_points", "genes", palette="okabe_ito") + sdata_blobs.pl.render_points("blobs_points", color="genes", palette=palette).pl.show() + + def test_plot_dict_palette_hex_shapes(self, sdata_blobs: SpatialData): + sdata_blobs["blobs_polygons"]["cat_col"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + palette = make_palette_from_data(sdata_blobs, "blobs_polygons", "cat_col", palette="okabe_ito") + sdata_blobs.pl.render_shapes("blobs_polygons", color="cat_col", palette=palette).pl.show() + + def test_plot_dict_palette_hex_labels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="region", palette={"blobs_labels": "#E69F00"}).pl.show() + + def test_plot_dict_palette_named_colors_points(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_points( + "blobs_points", color="genes", palette={"gene_a": "red", "gene_b": "dodgerblue"} + ).pl.show() + + def test_plot_dict_palette_named_colors_shapes(self, sdata_blobs: SpatialData): + sdata_blobs["blobs_polygons"]["cat_col"] = pd.Series(["a", "b", "a", "b", "a"], dtype="category") + sdata_blobs.pl.render_shapes( + "blobs_polygons", color="cat_col", palette={"a": "forestgreen", "b": "orchid"} + ).pl.show() + + def test_plot_dict_palette_named_colors_labels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="region", palette={"blobs_labels": "coral"}).pl.show()