diff --git a/src/spatialdata_plot/pl/_palette.py b/src/spatialdata_plot/pl/_palette.py index 05e2617b..4159de3a 100644 --- a/src/spatialdata_plot/pl/_palette.py +++ b/src/spatialdata_plot/pl/_palette.py @@ -5,8 +5,7 @@ - :func:`make_palette` — produce *n* colours, optionally reordered for maximum perceptual contrast or colourblind accessibility. - :func:`make_palette_from_data` — like :func:`make_palette` but derives - the number of colours and (for ``spaco`` methods) the assignment order - from a :class:`~spatialdata.SpatialData` element. + the number of colours from a :class:`~spatialdata.SpatialData` element. Both share the same *palette* / *method* vocabulary. The *palette* parameter controls **which** colours are used (the source), while @@ -22,13 +21,8 @@ from matplotlib.colors import ListedColormap, to_hex, to_rgb from matplotlib.pyplot import colormaps as mpl_colormaps from scanpy.plotting.palettes import default_20, default_28, default_102 -from scipy.spatial import cKDTree - -from spatialdata_plot._logging import logger if TYPE_CHECKING: - from collections.abc import Sequence - import spatialdata as sd # --------------------------------------------------------------------------- @@ -163,9 +157,6 @@ def _optimize_assignment( ) -> np.ndarray: """Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``. - Works for both spatial interlacement weights (spaco) and uniform - weights (pure contrast maximization). - Returns an index array: ``perm[category_idx] = color_idx``. """ if rng is None: @@ -233,56 +224,6 @@ def _optimized_order( return [to_hex(rgb[perm[i]]) for i in range(n)] -# --------------------------------------------------------------------------- -# Spatial interlacement (spaco-specific) -# --------------------------------------------------------------------------- - - -def _spatial_interlacement( - coords: np.ndarray, - labels: np.ndarray, - categories: Sequence[str], - n_neighbors: int = 15, -) -> np.ndarray: - """Build a symmetric interlacement matrix (n_categories × n_categories). - - Entry (i, j) reflects how much categories i and j are spatially - interleaved, measured by inverse-distance-weighted neighbor counts. - """ - n_cat = len(categories) - cat_to_idx = {c: i for i, c in enumerate(categories)} - label_idx = np.array([cat_to_idx[l] for l in labels]) - - tree = cKDTree(coords) - dists, indices = tree.query(coords, k=min(n_neighbors + 1, len(coords))) - - # Vectorized accumulation (avoids Python double-loop over cells × neighbors) - neighbor_dists = dists[:, 1:] - neighbor_indices = indices[:, 1:] - cell_cats = label_idx - neighbor_cats = label_idx[neighbor_indices] - - # Mask: different category and positive distance - cross_cat = neighbor_cats != cell_cats[:, np.newaxis] - valid_dist = neighbor_dists > 0 - mask = cross_cat & valid_dist - - weights = np.where(mask, 1.0 / np.where(neighbor_dists > 0, neighbor_dists, 1.0), 0.0) - - rows = np.broadcast_to(cell_cats[:, np.newaxis], neighbor_cats.shape)[mask] - cols = neighbor_cats[mask] - vals = weights[mask] - - mat = np.zeros((n_cat, n_cat), dtype=np.float64) - np.add.at(mat, (rows, cols), vals) - - mat = np.maximum(mat, mat.T) - max_val = mat.max() - if max_val > 0: - mat /= max_val - return mat # type: ignore[no-any-return] - - # --------------------------------------------------------------------------- # Palette resolution # --------------------------------------------------------------------------- @@ -339,35 +280,24 @@ def _resolve_element( element: str, color: str, table_name: str | None = None, -) -> tuple[np.ndarray, pd.Categorical]: - """Extract coordinates and categorical labels from a SpatialData element. +) -> pd.Categorical: + """Extract categorical labels from a SpatialData element. - Coordinates come from the element geometry (shapes) or x/y columns - (points). Labels come from a column on the element itself, or from - a linked table (joined on the instance key to guarantee alignment). + Labels come from a column on the element itself, or from a linked + table (joined on the instance key to guarantee alignment). """ if element in sdata.shapes: gdf = sdata.shapes[element] - coords = np.column_stack([gdf.geometry.centroid.x, gdf.geometry.centroid.y]) if color in gdf.columns: labels_series = gdf[color] else: - labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name) - # Align coords to table rows via matched instance indices - coords = coords[matched_indices] + labels_series = _get_labels_from_table(sdata, element, color, table_name) elif element in sdata.points: ddf = sdata.points[element] - if "x" not in ddf.columns or "y" not in ddf.columns: - raise ValueError(f"Points element '{element}' does not have 'x' and 'y' columns.") if color in ddf.columns: - df = ddf[["x", "y", color]].compute() - coords = df[["x", "y"]].values.astype(np.float64) - labels_series = df[color] + labels_series = ddf[[color]].compute()[color] else: - df = ddf[["x", "y"]].compute() - coords = df[["x", "y"]].values.astype(np.float64) - labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name) - coords = coords[matched_indices] + labels_series = _get_labels_from_table(sdata, element, color, table_name) else: available = list(sdata.shapes.keys()) + list(sdata.points.keys()) raise KeyError( @@ -376,8 +306,7 @@ def _resolve_element( ) is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype) - labels_cat = labels_series.values if is_categorical else pd.Categorical(labels_series) - return coords, labels_cat + return labels_series.values if is_categorical else pd.Categorical(labels_series) def _get_labels_from_table( @@ -385,15 +314,8 @@ def _get_labels_from_table( element: str, color: str, table_name: str | None = None, -) -> tuple[pd.Series, np.ndarray]: - """Extract a column from the table linked to an element. - - Returns (labels_series, element_indices) where element_indices maps - each table row to its position in the element, ensuring coord-label - alignment. - """ - from spatialdata.models import get_table_keys - +) -> pd.Series: + """Extract a column from the table linked to an element.""" matches: list[str] = [] for name in sdata.tables: table = sdata.tables[name] @@ -423,29 +345,7 @@ def _get_labels_from_table( ) table = sdata.tables[resolved_name] - _, _, instance_key = get_table_keys(table) - - # Join on instance key to align table rows with element positions - instance_ids = table.obs[instance_key].values - element_index = sdata.shapes[element].index if element in sdata.shapes else sdata.points[element].compute().index - - # Map each table instance_id to its position in the element index - element_idx_map = {val: i for i, val in enumerate(element_index)} - matched_indices = [] - valid_mask = [] - for iid in instance_ids: - if iid in element_idx_map: - matched_indices.append(element_idx_map[iid]) - valid_mask.append(True) - else: - valid_mask.append(False) - - valid_mask_arr = np.array(valid_mask) - if not any(valid_mask): - raise ValueError(f"No matching instance keys between table '{resolved_name}' and element '{element}'.") - - labels = table.obs.loc[valid_mask_arr, color] - return labels.reset_index(drop=True), np.array(matched_indices) + return table.obs[color].reset_index(drop=True) # --------------------------------------------------------------------------- @@ -461,16 +361,7 @@ def _get_labels_from_table( "tritanopia": "tritanopia", } -# Maps spaco methods → CVD type (None = normal vision). -_SPACO_CVD_TYPES: dict[str, str | None] = { - "spaco": None, - "spaco_colorblind": "general", - "spaco_protanopia": "protanopia", - "spaco_deuteranopia": "deuteranopia", - "spaco_tritanopia": "tritanopia", -} - -_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES, *_SPACO_CVD_TYPES}) +_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES}) # --------------------------------------------------------------------------- @@ -484,11 +375,6 @@ def _get_labels_from_table( "protanopia", "deuteranopia", "tritanopia", - "spaco", - "spaco_colorblind", - "spaco_protanopia", - "spaco_deuteranopia", - "spaco_tritanopia", ] @@ -528,9 +414,6 @@ def make_palette( under worst-case colour-vision deficiency. - ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` — reorder for a specific colour-vision deficiency. - - The ``spaco*`` methods require spatial data and are only - available via :func:`make_palette_from_data`. n_random Random permutations to try (optimisation methods only). n_swaps @@ -553,9 +436,6 @@ def make_palette( if n < 1: raise ValueError(f"n must be at least 1, got {n}.") - if method in _SPACO_CVD_TYPES: - raise ValueError(f"Method '{method}' requires spatial data. Use make_palette_from_data() instead.") - colors = _resolve_palette(palette, n) if method == "default": @@ -577,7 +457,6 @@ def make_palette_from_data( palette: list[str] | str | None = None, method: Method = "default", table_name: str | None = None, - n_neighbors: int = 15, n_random: int = 5000, n_swaps: int = 10000, seed: int = 0, @@ -605,25 +484,13 @@ def make_palette_from_data( Name of the table to use when *color* is looked up from a linked table. Required when multiple tables annotate the same element. method - Strategy for assigning colours to categories. Accepts all - methods from :func:`make_palette` plus spatially-aware ones: + Strategy for assigning colours to categories: - ``"default"`` — assign in sorted category order (reproduces the current render-pipeline behaviour). - ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise - perceptual spread (ignores spatial layout). - - ``"spaco"`` — spatially-aware assignment (Jing et al., - *Patterns* 2023). Maximises perceptual contrast between - categories that are spatially interleaved. - - ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under - worst-case colour-vision deficiency (all three types). - - ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` / - ``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for - a specific colour-vision deficiency. - n_neighbors - Only used with ``spaco`` methods. Number of spatial neighbours - for the interlacement computation. + perceptual spread. n_random Random permutations to try (optimisation methods only). n_swaps @@ -641,11 +508,11 @@ def make_palette_from_data( -------- >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type") >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10") - >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco") - >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind") + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast") + >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind") >>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show() """ - coords, labels_cat = _resolve_element(sdata, element, color, table_name=table_name) + labels_cat = _resolve_element(sdata, element, color, table_name=table_name) categories = list(labels_cat.categories) n_cat = len(categories) @@ -657,7 +524,6 @@ def make_palette_from_data( if method == "default": return {cat: to_hex(to_rgb(c)) for cat, c in zip(categories, colors_list, strict=True)} - # Non-spatial contrast methods (same as make_palette but returns dict) if method in _CONTRAST_CVD_TYPES: cvd_type = _CONTRAST_CVD_TYPES[method] reordered = _optimized_order( @@ -665,34 +531,5 @@ def make_palette_from_data( ) return dict(zip(categories, reordered, strict=True)) - # Spaco methods (spatially-aware) - if method in _SPACO_CVD_TYPES: - cvd_type = _SPACO_CVD_TYPES[method] - - # Filter NaN labels - mask = labels_cat.codes != -1 - coords_clean = coords[mask] - labels_clean = np.array(categories)[labels_cat.codes[mask]] - - if len(coords_clean) == 0: - raise ValueError(f"All values in column '{color}' are NaN.") - - rgb = np.array([to_rgb(c) for c in colors_list]) - - if n_cat == 1: - return {categories[0]: to_hex(rgb[0])} - - logger.info(f"Computing spatial interlacement for {n_cat} categories ({len(coords_clean)} cells)...") - inter = _spatial_interlacement(coords_clean, labels_clean, categories, n_neighbors=n_neighbors) - - logger.info("Computing perceptual distance matrix...") - cdist = _perceptual_distance_matrix(rgb, colorblind_type=cvd_type) - - logger.info("Optimizing color assignment...") - rng = np.random.default_rng(seed) - perm = _optimize_assignment(inter, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng) - - return {cat: to_hex(rgb[perm[i]]) for i, cat in enumerate(categories)} - valid = ", ".join(f"'{m}'" for m in _ALL_METHODS) raise ValueError(f"Unknown method '{method}'. Choose from {valid}.") diff --git a/tests/pl/test_palette.py b/tests/pl/test_palette.py index 37affabb..311209d0 100644 --- a/tests/pl/test_palette.py +++ b/tests/pl/test_palette.py @@ -18,7 +18,6 @@ _perceptual_distance_matrix, _rgb_to_oklab, _simulate_cvd, - _spatial_interlacement, make_palette, make_palette_from_data, ) @@ -32,13 +31,10 @@ def _build_clustered_points_sdata(seed: int = 0) -> SpatialData: - """SpatialData with interleaved A/B clusters near origin and isolated C far away.""" + """SpatialData with three categorical labels (A, B, C) on a points element.""" rng = np.random.default_rng(seed) - coords_a = np.array([[0, 0], [1, 0], [0, 1]], dtype=float) + rng.normal(0, 0.05, (3, 2)) - coords_b = np.array([[0.5, 0.5], [1.5, 0.5], [0.5, 1.5]], dtype=float) + rng.normal(0, 0.05, (3, 2)) - coords_c = np.array([[10, 10], [11, 10], [10, 11]], dtype=float) + rng.normal(0, 0.05, (3, 2)) - - coords = np.vstack([coords_a, coords_b, coords_c]) + n = 9 + coords = rng.normal(size=(n, 2)) labels = pd.Categorical(["A"] * 3 + ["B"] * 3 + ["C"] * 3) df = pd.DataFrame({"x": coords[:, 0], "y": coords[:, 1], "cell_type": labels}) return SpatialData(points={"cells": PointsModel.parse(df)}) @@ -120,23 +116,6 @@ def test_red_green_less_distinct(self, cvd_type: str): assert _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)[0, 1] < _perceptual_distance_matrix(rgb)[0, 1] -class TestSpatialInterlacement: - def test_interleaved_higher_than_separated(self): - coords = np.array([[0, 0], [1, 0], [0.5, 0.5], [1.5, 0.5], [10, 10], [11, 10]]) - mat = _spatial_interlacement(coords, np.array(["A", "B", "A", "B", "C", "C"]), ["A", "B", "C"], n_neighbors=3) - assert mat[0, 1] > mat[0, 2] - assert mat[0, 1] > mat[1, 2] - - def test_diagonal_is_zero(self): - mat = _spatial_interlacement(np.array([[0, 0], [1, 0], [0.5, 0.5]]), np.array(["A", "B", "A"]), ["A", "B"], 2) - np.testing.assert_allclose(np.diag(mat), 0) - - def test_symmetric(self): - rng = np.random.default_rng(42) - mat = _spatial_interlacement(rng.normal(size=(50, 2)), np.array(list("ABCDE") * 10), list("ABCDE"), 5) - np.testing.assert_allclose(mat, mat.T) - - class TestOptimizer: def test_single_category(self): assert list(_optimize_assignment(np.zeros((1, 1)), np.zeros((1, 1)))) == [0] @@ -196,11 +175,6 @@ def test_too_few_colors_raises(self): with pytest.raises(ValueError, match="needed"): make_palette(10, palette=["red", "blue"]) - @pytest.mark.parametrize("method", ["spaco", "spaco_colorblind"]) - def test_spaco_methods_raise(self, method: str): - with pytest.raises(ValueError, match="requires spatial data"): - make_palette(3, method=method) # type: ignore[arg-type] - def test_unknown_method_raises(self): with pytest.raises(ValueError, match="Unknown method"): make_palette(3, method="invalid") # type: ignore[arg-type] @@ -239,56 +213,17 @@ def test_named_palette_sources(self, clustered_sdata: SpatialData, palette: str) result = make_palette_from_data(clustered_sdata, "cells", "cell_type", palette=palette) assert isinstance(result, dict) and len(result) == 3 - @pytest.mark.parametrize( - "method", - ["contrast", "colorblind", "spaco", "spaco_colorblind", "spaco_deuteranopia"], - ) - def test_all_methods_return_valid_dict(self, clustered_sdata: SpatialData, method: str): + @pytest.mark.parametrize("method", ["contrast", "colorblind"]) + def test_optimization_methods_return_valid_dict(self, clustered_sdata: SpatialData, method: str): result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method=method, seed=42) assert isinstance(result, dict) assert set(result.keys()) == {"A", "B", "C"} - def test_spaco_deterministic(self, clustered_sdata: SpatialData): - r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42) - r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=42) - assert r1 == r2 - - def test_spaco_different_seeds_can_differ(self, clustered_sdata: SpatialData): - r1 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=0) - r2 = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", seed=999) - assert set(r1.keys()) == set(r2.keys()) - - def test_spaco_custom_palette_is_permutation(self, clustered_sdata: SpatialData): - colors = ["#ff0000", "#00ff00", "#0000ff"] - result = make_palette_from_data(clustered_sdata, "cells", "cell_type", method="spaco", palette=colors, seed=42) - assert set(result.values()) == {to_hex(to_rgb(c)) for c in colors} - - def test_spaco_single_category(self): - df = pd.DataFrame({"x": [0.0, 1.0], "y": [0.0, 1.0], "ct": pd.Categorical(["A", "A"])}) - sdata = SpatialData(points={"pts": PointsModel.parse(df)}) - result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0) - assert len(result) == 1 and "A" in result - - def test_spaco_nan_labels_filtered(self): - df = pd.DataFrame( - {"x": [0.0, 1.0, 0.0, 10.0], "y": [0.0, 0.0, 1.0, 10.0], "ct": pd.Categorical(["A", "B", "A", None])} - ) - sdata = SpatialData(points={"pts": PointsModel.parse(df)}) - result = make_palette_from_data(sdata, "pts", "ct", method="spaco", seed=0) - assert {"A", "B"} <= set(result.keys()) - def test_shapes_with_table(self, shapes_sdata: SpatialData): - result = make_palette_from_data(shapes_sdata, "my_shapes", "cell_type", method="spaco", seed=42) + result = make_palette_from_data(shapes_sdata, "my_shapes", "cell_type", seed=42) assert isinstance(result, dict) assert set(result.keys()) == {"X", "Y", "Z"} - def test_interleaved_get_distinct_colors(self): - sdata = _build_clustered_points_sdata(seed=0) - palette = ["#ff0000", "#ff1100", "#0000ff"] - result = make_palette_from_data(sdata, "cells", "cell_type", method="spaco", palette=palette, seed=0) - # A and B (interleaved) should not both get red-ish colors - assert result["A"] == "#0000ff" or result["B"] == "#0000ff" - # --------------------------------------------------------------------------- # Error cases