Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 18 additions & 181 deletions src/spatialdata_plot/pl/_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
- :func:`make_palette` — produce *n* colours, optionally reordered for
maximum perceptual contrast or colourblind accessibility.
- :func:`make_palette_from_data` — like :func:`make_palette` but derives
the number of colours and (for ``spaco`` methods) the assignment order
from a :class:`~spatialdata.SpatialData` element.
the number of colours from a :class:`~spatialdata.SpatialData` element.

Both share the same *palette* / *method* vocabulary. The *palette*
parameter controls **which** colours are used (the source), while
Expand All @@ -22,13 +21,8 @@
from matplotlib.colors import ListedColormap, to_hex, to_rgb
from matplotlib.pyplot import colormaps as mpl_colormaps
from scanpy.plotting.palettes import default_20, default_28, default_102
from scipy.spatial import cKDTree

from spatialdata_plot._logging import logger

if TYPE_CHECKING:
from collections.abc import Sequence

import spatialdata as sd

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -163,9 +157,6 @@ def _optimize_assignment(
) -> np.ndarray:
"""Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``.

Works for both spatial interlacement weights (spaco) and uniform
weights (pure contrast maximization).

Returns an index array: ``perm[category_idx] = color_idx``.
"""
if rng is None:
Expand Down Expand Up @@ -233,56 +224,6 @@ def _optimized_order(
return [to_hex(rgb[perm[i]]) for i in range(n)]


# ---------------------------------------------------------------------------
# Spatial interlacement (spaco-specific)
# ---------------------------------------------------------------------------


def _spatial_interlacement(
coords: np.ndarray,
labels: np.ndarray,
categories: Sequence[str],
n_neighbors: int = 15,
) -> np.ndarray:
"""Build a symmetric interlacement matrix (n_categories × n_categories).

Entry (i, j) reflects how much categories i and j are spatially
interleaved, measured by inverse-distance-weighted neighbor counts.
"""
n_cat = len(categories)
cat_to_idx = {c: i for i, c in enumerate(categories)}
label_idx = np.array([cat_to_idx[l] for l in labels])

tree = cKDTree(coords)
dists, indices = tree.query(coords, k=min(n_neighbors + 1, len(coords)))

# Vectorized accumulation (avoids Python double-loop over cells × neighbors)
neighbor_dists = dists[:, 1:]
neighbor_indices = indices[:, 1:]
cell_cats = label_idx
neighbor_cats = label_idx[neighbor_indices]

# Mask: different category and positive distance
cross_cat = neighbor_cats != cell_cats[:, np.newaxis]
valid_dist = neighbor_dists > 0
mask = cross_cat & valid_dist

weights = np.where(mask, 1.0 / np.where(neighbor_dists > 0, neighbor_dists, 1.0), 0.0)

rows = np.broadcast_to(cell_cats[:, np.newaxis], neighbor_cats.shape)[mask]
cols = neighbor_cats[mask]
vals = weights[mask]

mat = np.zeros((n_cat, n_cat), dtype=np.float64)
np.add.at(mat, (rows, cols), vals)

mat = np.maximum(mat, mat.T)
max_val = mat.max()
if max_val > 0:
mat /= max_val
return mat # type: ignore[no-any-return]


# ---------------------------------------------------------------------------
# Palette resolution
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -339,35 +280,24 @@ def _resolve_element(
element: str,
color: str,
table_name: str | None = None,
) -> tuple[np.ndarray, pd.Categorical]:
"""Extract coordinates and categorical labels from a SpatialData element.
) -> pd.Categorical:
"""Extract categorical labels from a SpatialData element.

Coordinates come from the element geometry (shapes) or x/y columns
(points). Labels come from a column on the element itself, or from
a linked table (joined on the instance key to guarantee alignment).
Labels come from a column on the element itself, or from a linked
table (joined on the instance key to guarantee alignment).
"""
if element in sdata.shapes:
gdf = sdata.shapes[element]
coords = np.column_stack([gdf.geometry.centroid.x, gdf.geometry.centroid.y])
if color in gdf.columns:
labels_series = gdf[color]
else:
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
# Align coords to table rows via matched instance indices
coords = coords[matched_indices]
labels_series = _get_labels_from_table(sdata, element, color, table_name)
elif element in sdata.points:
ddf = sdata.points[element]
if "x" not in ddf.columns or "y" not in ddf.columns:
raise ValueError(f"Points element '{element}' does not have 'x' and 'y' columns.")
if color in ddf.columns:
df = ddf[["x", "y", color]].compute()
coords = df[["x", "y"]].values.astype(np.float64)
labels_series = df[color]
labels_series = ddf[[color]].compute()[color]
else:
df = ddf[["x", "y"]].compute()
coords = df[["x", "y"]].values.astype(np.float64)
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
coords = coords[matched_indices]
labels_series = _get_labels_from_table(sdata, element, color, table_name)
else:
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
raise KeyError(
Expand All @@ -376,24 +306,16 @@ def _resolve_element(
)

is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype)
labels_cat = labels_series.values if is_categorical else pd.Categorical(labels_series)
return coords, labels_cat
return labels_series.values if is_categorical else pd.Categorical(labels_series)


def _get_labels_from_table(
sdata: sd.SpatialData,
element: str,
color: str,
table_name: str | None = None,
) -> tuple[pd.Series, np.ndarray]:
"""Extract a column from the table linked to an element.

Returns (labels_series, element_indices) where element_indices maps
each table row to its position in the element, ensuring coord-label
alignment.
"""
from spatialdata.models import get_table_keys

) -> pd.Series:
"""Extract a column from the table linked to an element."""
matches: list[str] = []
for name in sdata.tables:
table = sdata.tables[name]
Expand Down Expand Up @@ -423,29 +345,7 @@ def _get_labels_from_table(
)

table = sdata.tables[resolved_name]
_, _, instance_key = get_table_keys(table)

# Join on instance key to align table rows with element positions
instance_ids = table.obs[instance_key].values
element_index = sdata.shapes[element].index if element in sdata.shapes else sdata.points[element].compute().index

# Map each table instance_id to its position in the element index
element_idx_map = {val: i for i, val in enumerate(element_index)}
matched_indices = []
valid_mask = []
for iid in instance_ids:
if iid in element_idx_map:
matched_indices.append(element_idx_map[iid])
valid_mask.append(True)
else:
valid_mask.append(False)

valid_mask_arr = np.array(valid_mask)
if not any(valid_mask):
raise ValueError(f"No matching instance keys between table '{resolved_name}' and element '{element}'.")

labels = table.obs.loc[valid_mask_arr, color]
return labels.reset_index(drop=True), np.array(matched_indices)
return table.obs[color].reset_index(drop=True)


# ---------------------------------------------------------------------------
Expand All @@ -461,16 +361,7 @@ def _get_labels_from_table(
"tritanopia": "tritanopia",
}

# Maps spaco methods → CVD type (None = normal vision).
# Keys are the spatially-aware ``method`` names; each value is the
# colour-vision-deficiency type used when the perceptual distance between
# palette colours is computed for that method.
_SPACO_CVD_TYPES: dict[str, str | None] = {
"spaco": None,
"spaco_colorblind": "general",
"spaco_protanopia": "protanopia",
"spaco_deuteranopia": "deuteranopia",
"spaco_tritanopia": "tritanopia",
}

_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES, *_SPACO_CVD_TYPES})
_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES})


# ---------------------------------------------------------------------------
Expand All @@ -484,11 +375,6 @@ def _get_labels_from_table(
"protanopia",
"deuteranopia",
"tritanopia",
"spaco",
"spaco_colorblind",
"spaco_protanopia",
"spaco_deuteranopia",
"spaco_tritanopia",
]


Expand Down Expand Up @@ -528,9 +414,6 @@ def make_palette(
under worst-case colour-vision deficiency.
- ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` —
reorder for a specific colour-vision deficiency.

The ``spaco*`` methods require spatial data and are only
available via :func:`make_palette_from_data`.
n_random
Random permutations to try (optimisation methods only).
n_swaps
Expand All @@ -553,9 +436,6 @@ def make_palette(
if n < 1:
raise ValueError(f"n must be at least 1, got {n}.")

if method in _SPACO_CVD_TYPES:
raise ValueError(f"Method '{method}' requires spatial data. Use make_palette_from_data() instead.")

colors = _resolve_palette(palette, n)

if method == "default":
Expand All @@ -577,7 +457,6 @@ def make_palette_from_data(
palette: list[str] | str | None = None,
method: Method = "default",
table_name: str | None = None,
n_neighbors: int = 15,
n_random: int = 5000,
n_swaps: int = 10000,
seed: int = 0,
Expand Down Expand Up @@ -605,25 +484,13 @@ def make_palette_from_data(
Name of the table to use when *color* is looked up from a linked
table. Required when multiple tables annotate the same element.
method
Strategy for assigning colours to categories. Accepts all
methods from :func:`make_palette` plus spatially-aware ones:
Strategy for assigning colours to categories:

- ``"default"`` — assign in sorted category order (reproduces
the current render-pipeline behaviour).
- ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` /
``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise
perceptual spread (ignores spatial layout).
- ``"spaco"`` — spatially-aware assignment (Jing et al.,
*Patterns* 2023). Maximises perceptual contrast between
categories that are spatially interleaved.
- ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under
worst-case colour-vision deficiency (all three types).
- ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` /
``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for
a specific colour-vision deficiency.
n_neighbors
Only used with ``spaco`` methods. Number of spatial neighbours
for the interlacement computation.
perceptual spread.
n_random
Random permutations to try (optimisation methods only).
n_swaps
Expand All @@ -641,11 +508,11 @@ def make_palette_from_data(
--------
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type")
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10")
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco")
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind")
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast")
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind")
>>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show()
"""
coords, labels_cat = _resolve_element(sdata, element, color, table_name=table_name)
labels_cat = _resolve_element(sdata, element, color, table_name=table_name)

categories = list(labels_cat.categories)
n_cat = len(categories)
Expand All @@ -657,42 +524,12 @@ def make_palette_from_data(
if method == "default":
return {cat: to_hex(to_rgb(c)) for cat, c in zip(categories, colors_list, strict=True)}

# Non-spatial contrast methods (same as make_palette but returns dict)
if method in _CONTRAST_CVD_TYPES:
cvd_type = _CONTRAST_CVD_TYPES[method]
reordered = _optimized_order(
colors_list, colorblind_type=cvd_type, n_random=n_random, n_swaps=n_swaps, seed=seed
)
return dict(zip(categories, reordered, strict=True))

# Spaco methods (spatially-aware)
if method in _SPACO_CVD_TYPES:
cvd_type = _SPACO_CVD_TYPES[method]

# Filter NaN labels
mask = labels_cat.codes != -1
coords_clean = coords[mask]
labels_clean = np.array(categories)[labels_cat.codes[mask]]

if len(coords_clean) == 0:
raise ValueError(f"All values in column '{color}' are NaN.")

rgb = np.array([to_rgb(c) for c in colors_list])

if n_cat == 1:
return {categories[0]: to_hex(rgb[0])}

logger.info(f"Computing spatial interlacement for {n_cat} categories ({len(coords_clean)} cells)...")
inter = _spatial_interlacement(coords_clean, labels_clean, categories, n_neighbors=n_neighbors)

logger.info("Computing perceptual distance matrix...")
cdist = _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)

logger.info("Optimizing color assignment...")
rng = np.random.default_rng(seed)
perm = _optimize_assignment(inter, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng)

return {cat: to_hex(rgb[perm[i]]) for i, cat in enumerate(categories)}

valid = ", ".join(f"'{m}'" for m in _ALL_METHODS)
raise ValueError(f"Unknown method '{method}'. Choose from {valid}.")
Loading
Loading