Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def render_shapes(
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
gene_symbols: str | None = None,
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
Expand Down Expand Up @@ -263,6 +264,10 @@ def render_shapes(
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.
gene_symbols: str | None
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
shape: Literal["circle", "hex", "visium_hex", "square"] | None
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is
Expand Down Expand Up @@ -313,6 +318,7 @@ def render_shapes(
ds_reduction=kwargs.get("datashader_reduction"),
colorbar=colorbar,
colorbar_params=colorbar_params,
gene_symbols=gene_symbols,
)

sdata = self._copy()
Expand Down Expand Up @@ -370,6 +376,7 @@ def render_points(
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
gene_symbols: str | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -434,6 +441,10 @@ def render_points(
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.
gene_symbols: str | None
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down Expand Up @@ -467,6 +478,7 @@ def render_points(
ds_reduction=kwargs.get("datashader_reduction"),
colorbar=colorbar,
colorbar_params=colorbar_params,
gene_symbols=gene_symbols,
)

if method is not None:
Expand Down Expand Up @@ -706,6 +718,7 @@ def render_labels(
colorbar_params: dict[str, object] | None = None,
table_name: str | None = None,
table_layer: str | None = None,
gene_symbols: str | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -775,6 +788,10 @@ def render_labels(
table_layer: str | None
Layer of the AnnData table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None,
:attr:`sdata.table.X` of the default table is used for coloring.
gene_symbols: str | None
Column name in :attr:`sdata.table.var` to use for looking up ``color``. Use this when
``var_names`` are e.g. ENSEMBL IDs but you want to refer to genes by their symbols stored
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
kwargs
Additional arguments to be passed to cmap and norm.

Expand Down Expand Up @@ -803,6 +820,7 @@ def render_labels(
colorbar_params=colorbar_params,
table_name=table_name,
table_layer=table_layer,
gene_symbols=gene_symbols,
)

sdata = self._copy()
Expand Down
56 changes: 49 additions & 7 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,6 +2550,7 @@ def _validate_label_render_params(
table_layer: str | None,
colorbar: bool | str | None,
colorbar_params: dict[str, object] | None,
gene_symbols: str | None = None,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand Down Expand Up @@ -2593,7 +2594,7 @@ def _validate_label_render_params(
element_params[el]["col_for_color"] = None
if (col_for_color := param_dict["col_for_color"]) is not None:
col_for_color, table_name = _validate_col_for_column_table(
sdata, el, col_for_color, param_dict["table_name"], labels=True
sdata, el, col_for_color, param_dict["table_name"], labels=True, gene_symbols=gene_symbols
)
element_params[el]["table_name"] = table_name
element_params[el]["col_for_color"] = col_for_color
Expand Down Expand Up @@ -2621,6 +2622,7 @@ def _validate_points_render_params(
ds_reduction: str | None,
colorbar: bool | str | None,
colorbar_params: dict[str, object] | None,
gene_symbols: str | None = None,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand Down Expand Up @@ -2660,7 +2662,7 @@ def _validate_points_render_params(
col_for_color = param_dict["col_for_color"]
if col_for_color is not None:
col_for_color, table_name = _validate_col_for_column_table(
sdata, el, col_for_color, param_dict["table_name"]
sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols
)
element_params[el]["table_name"] = table_name
element_params[el]["col_for_color"] = col_for_color
Expand Down Expand Up @@ -2694,6 +2696,7 @@ def _validate_shape_render_params(
ds_reduction: str | None,
colorbar: bool | str | None,
colorbar_params: dict[str, object] | None,
gene_symbols: str | None = None,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand Down Expand Up @@ -2743,7 +2746,7 @@ def _validate_shape_render_params(
col_for_color = param_dict["col_for_color"]
if col_for_color is not None:
col_for_color, table_name = _validate_col_for_column_table(
sdata, el, col_for_color, param_dict["table_name"]
sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols
)
element_params[el]["table_name"] = table_name
element_params[el]["col_for_color"] = col_for_color
Expand All @@ -2757,12 +2760,38 @@ def _validate_shape_render_params(
return element_params


def _resolve_gene_symbols(
adata: AnnData,
col_for_color: str,
gene_symbols: str,
) -> str:
"""Resolve a gene symbol to its var_name using an alternate var column.

Mimics scanpy's ``gene_symbols`` behaviour: look up *col_for_color* in
``adata.var[gene_symbols]`` and return the corresponding ``var_name``
(i.e. the var index value).
"""
if gene_symbols not in adata.var.columns:
raise KeyError(f"Column '{gene_symbols}' not found in `adata.var`. Cannot use it as `gene_symbols` lookup.")
mask = adata.var[gene_symbols] == col_for_color
if not mask.any():
raise KeyError(f"'{col_for_color}' not found in `adata.var['{gene_symbols}']`.")
n_matches = mask.sum()
if n_matches > 1:
logger.warning(
f"Gene symbol '{col_for_color}' maps to {n_matches} var_names in column '{gene_symbols}'. "
f"Using the first match: '{adata.var.index[mask][0]}'."
)
return str(adata.var.index[mask][0])


def _validate_col_for_column_table(
sdata: SpatialData,
element_name: str,
col_for_color: str | None,
table_name: str | None,
labels: bool = False,
gene_symbols: str | None = None,
) -> tuple[str | None, str | None]:
if col_for_color is None:
return None, None
Expand All @@ -2775,9 +2804,13 @@ def _validate_col_for_column_table(
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")
raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.")
if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names:
raise KeyError(
f"Column '{col_for_color}' not found in obs/var of table '{table_name}' for element '{element_name}'."
)
if gene_symbols is not None:
col_for_color = _resolve_gene_symbols(sdata[table_name], col_for_color, gene_symbols)
else:
raise KeyError(
f"Column '{col_for_color}' not found in obs/var of table '{table_name}' "
f"for element '{element_name}'."
)
else:
tables = get_element_annotators(sdata, element_name)
if len(tables) == 0:
Expand All @@ -2787,9 +2820,16 @@ def _validate_col_for_column_table(
"Please ensure the element is annotated by at least one table."
)
# Now check which tables contain the column
resolved_var_name: str | None = None
for annotates in tables.copy():
if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names:
tables.remove(annotates)
if gene_symbols is not None:
try:
resolved_var_name = _resolve_gene_symbols(sdata[annotates], col_for_color, gene_symbols)
except KeyError:
tables.remove(annotates)
else:
tables.remove(annotates)
if len(tables) == 0:
raise KeyError(
f"Unable to locate color key '{col_for_color}' for element '{element_name}'. "
Expand All @@ -2798,6 +2838,8 @@ def _validate_col_for_column_table(
table_name = next(iter(tables))
if len(tables) > 1:
logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.")
if resolved_var_name is not None:
col_for_color = resolved_var_name
return col_for_color, table_name


Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ def test_plot_can_annotate_labels_with_nan_in_table_X_continuous(self, sdata_blo
sdata_blobs["table"].X[0:5, 0] = np.nan
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()

def test_plot_can_color_labels_by_gene_symbols(self, sdata_blobs: SpatialData):
"""Color labels by gene symbol alias instead of var_name (#247)."""
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
sdata_blobs.pl.render_labels(
"blobs_labels", color="GeneA", table_name="table", gene_symbols="gene_symbol"
).pl.show()


def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
Expand Down
20 changes: 20 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,26 @@ def test_plot_sampled_points_categorical_color_datashader(self):
"""Regression test for #358: .sample() must not shuffle categorical colors."""
self._make_sampled_sdata().pl.render_points("pts", color="cluster", method="datashader").pl.show()

def test_plot_can_color_points_by_gene_symbols(self, sdata_blobs: SpatialData):
"""Color points by gene symbol alias instead of var_name (#247)."""
rng = get_standard_RNG()
pts = sdata_blobs["blobs_points"].compute()
n_obs = len(pts)
# Assign unique instance IDs to each point
pts["instance_id"] = np.arange(n_obs)
sdata_blobs["blobs_points"] = PointsModel.parse(pts)
adata = AnnData(
X=rng.random((n_obs, 3)),
var=pd.DataFrame({"gene_symbol": ["GeneA", "GeneB", "GeneC"]}, index=["f0", "f1", "f2"]),
)
adata.obs["region"] = pd.Categorical(["blobs_points"] * n_obs)
adata.obs["instance_id"] = np.arange(n_obs)
table = TableModel.parse(adata, region="blobs_points", region_key="region", instance_key="instance_id")
sdata_blobs["table"] = table
sdata_blobs.pl.render_points(
"blobs_points", color="GeneA", table_name="table", gene_symbols="gene_symbol", size=10
).pl.show()


def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
"""When no elements match the groups, the plot should render without error."""
Expand Down
38 changes: 38 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,44 @@ def test_plot_groups_na_color_none_filters_shapes_datashader(self, sdata_blobs:
ax=axs[1], title="default (filtered)"
)

def test_plot_can_color_shapes_by_gene_symbols(self, sdata_blobs: SpatialData):
"""Color shapes by gene symbol alias instead of var_name (#247)."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
sdata_blobs.pl.render_shapes(
"blobs_circles", color="GeneA", table_name="table", gene_symbols="gene_symbol"
).pl.show()


def test_gene_symbols_auto_detect_table(sdata_blobs: SpatialData):
"""gene_symbols resolves correctly without explicit table_name (#247)."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
# No table_name — auto-detect path
sdata_blobs.pl.render_shapes("blobs_circles", color="GeneA", gene_symbols="gene_symbol").pl.show()
plt.close("all")


def test_gene_symbols_missing_symbol_raises(sdata_blobs: SpatialData):
"""gene_symbols raises KeyError when the symbol is not found (#247)."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
sdata_blobs["table"].var["gene_symbol"] = ["GeneA", "GeneB", "GeneC"]
with pytest.raises(KeyError, match="Unable to locate color key 'NoSuchGene'"):
sdata_blobs.pl.render_shapes("blobs_circles", color="NoSuchGene", gene_symbols="gene_symbol").pl.show()


def test_gene_symbols_missing_column_raises(sdata_blobs: SpatialData):
"""gene_symbols raises KeyError when the var column doesn't exist (#247)."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles"
with pytest.raises(KeyError, match="not found in `adata.var`"):
sdata_blobs.pl.render_shapes(
"blobs_circles", color="GeneA", table_name="table", gene_symbols="nonexistent_col"
).pl.show()


def test_groups_na_color_none_no_match_shapes(sdata_blobs: SpatialData):
"""When no elements match the groups, the plot should render without error."""
Expand Down
Loading