Skip to content

Commit 1f10b21

Browse files
authored
Fix RGBA images rendered with categorical coloring (#563)
1 parent 5bb08ca commit 1f10b21

5 files changed

Lines changed: 307 additions & 10 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def render_points(
505505

506506
return sdata
507507

508-
@_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
508+
@_deprecation_alias(elements="element", version="version 0.3.0")
509509
def render_images(
510510
self,
511511
element: str | None = None,

src/spatialdata_plot/pl/render.py

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
_render_ds_outlines,
4040
)
4141
from spatialdata_plot.pl.render_params import (
42+
CmapParams,
4243
Color,
4344
ColorbarSpec,
4445
FigParams,
@@ -1018,6 +1019,63 @@ def _render_points(
10181019
)
10191020

10201021

1022+
def _normalize_dtype_to_float(arr: np.ndarray) -> np.ndarray:
1023+
"""Normalize an array to float64 in [0, 1] for display with matplotlib.
1024+
1025+
Intended for RGB/RGBA image data where negative values are not meaningful.
1026+
1027+
- uint8 → divide by 255
1028+
- other unsigned int → divide by dtype max
1029+
- signed int → divide by dtype max, clip negatives to 0
1030+
- float already in [0, 1] → pass through
1031+
- float outside [0, 1] → global auto-range (preserves relative balance across channels)
1032+
"""
1033+
if arr.dtype == np.uint8:
1034+
return arr.astype(np.float64) / 255.0
1035+
if arr.dtype.kind == "u":
1036+
return arr.astype(np.float64) / np.iinfo(arr.dtype).max
1037+
if arr.dtype.kind == "i":
1038+
return np.clip(arr.astype(np.float64) / np.iinfo(arr.dtype).max, 0, 1)
1039+
# Float: if already in [0, 1], keep as-is; otherwise auto-range globally
1040+
arr_f: np.ndarray = arr.astype(np.float64)
1041+
vmin, vmax = arr_f.min(), arr_f.max()
1042+
if vmin >= 0.0 and vmax <= 1.0:
1043+
return arr_f
1044+
if vmin == vmax:
1045+
return np.zeros_like(arr_f)
1046+
logger.info(
1047+
"Float RGB image has values outside [0, 1] (range [%.3f, %.3f]); "
1048+
"auto-ranging globally. Pass an explicit 'norm' to control contrast.",
1049+
vmin,
1050+
vmax,
1051+
)
1052+
result: np.ndarray = (arr_f - vmin) / (vmax - vmin)
1053+
return result
1054+
1055+
1056+
def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
1057+
"""Check if channel coordinates indicate an RGB(A) image.
1058+
1059+
Checks case-insensitively whether channel names are {r, g, b} or {r, g, b, a}.
1060+
1061+
Parameters
1062+
----------
1063+
channel_coords
1064+
The channel coordinate values from the image.
1065+
1066+
Returns
1067+
-------
1068+
tuple[bool, bool]
1069+
(is_rgb, has_alpha) — whether the image is RGB and whether it includes an alpha channel.
1070+
"""
1071+
names = {str(c).lower() for c in channel_coords}
1072+
if names == {"r", "g", "b", "a"} and len(channel_coords) == 4:
1073+
return True, True
1074+
if names == {"r", "g", "b"} and len(channel_coords) == 3:
1075+
return True, False
1076+
return False, False
1077+
1078+
10211079
def _render_images(
10221080
sdata: sd.SpatialData,
10231081
render_params: ImageRenderParams,
@@ -1083,6 +1141,50 @@ def _render_images(
10831141

10841142
_, trans_data = _prepare_transformation(img, coordinate_system, ax)
10851143

1144+
# Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
1145+
is_rgb, has_alpha = _is_rgb_image(channels)
1146+
has_explicit_cmap = (
1147+
isinstance(render_params.cmap_params, CmapParams) and not render_params.cmap_params.cmap_is_default
1148+
)
1149+
if is_rgb and palette is None and not got_multiple_cmaps and not has_explicit_cmap:
1150+
coord_map = {str(c).lower(): c for c in channels}
1151+
ordered = [coord_map[ch] for ch in ("r", "g", "b")]
1152+
1153+
# Apply norm per channel if user provided one, otherwise normalize by dtype
1154+
user_norm = (
1155+
render_params.cmap_params.norm
1156+
if isinstance(render_params.cmap_params, CmapParams)
1157+
and isinstance(render_params.cmap_params.norm, Normalize)
1158+
and (render_params.cmap_params.norm.vmin is not None or render_params.cmap_params.norm.vmax is not None)
1159+
else None
1160+
)
1161+
1162+
if user_norm is not None:
1163+
rgb_layers = []
1164+
for ch in ordered:
1165+
ch_norm = copy(user_norm)
1166+
rgb_layers.append(np.clip(ch_norm(img.sel(c=ch).values).astype(np.float64), 0, 1))
1167+
stacked = np.stack(rgb_layers, axis=-1)
1168+
else:
1169+
stacked = _normalize_dtype_to_float(np.moveaxis(img.sel(c=ordered).values, 0, -1))
1170+
1171+
show_kwargs: dict[str, Any] = {"zorder": render_params.zorder}
1172+
1173+
if has_alpha and render_params.alpha == 1.0:
1174+
alpha_layer = _normalize_dtype_to_float(img.sel(c=coord_map["a"]).values)
1175+
stacked = np.concatenate([stacked, alpha_layer[..., np.newaxis]], axis=-1)
1176+
else:
1177+
show_kwargs["alpha"] = render_params.alpha
1178+
if has_alpha:
1179+
logger.info(
1180+
"Image has an alpha channel, but an explicit 'alpha' value was provided. "
1181+
"Using the user-specified alpha=%.2f instead of the per-pixel alpha from the data.",
1182+
render_params.alpha,
1183+
)
1184+
1185+
_ax_show_and_transform(stacked, trans_data, ax, **show_kwargs)
1186+
return
1187+
10861188
# 1) Image has only 1 channel
10871189
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
10881190
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
@@ -1138,13 +1240,16 @@ def _render_images(
11381240
else:
11391241
ch_norm = render_params.cmap_params.norm
11401242

1141-
if ch_norm is not None:
1142-
layers[ch] = ch_norm(layers[ch])
1243+
# Auto-ranging norms are stateful — copy so each channel normalizes independently
1244+
if isinstance(ch_norm, Normalize) and (ch_norm.vmin is None or ch_norm.vmax is None):
1245+
ch_norm = copy(ch_norm)
1246+
1247+
layers[ch] = ch_norm(layers[ch])
11431248

11441249
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
11451250
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
11461251
if render_params.cmap_params.cmap_is_default: # -> use RGB
1147-
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
1252+
stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1)
11481253
else: # -> use given cmap for each channel
11491254
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
11501255
stacked = (
@@ -1182,15 +1287,15 @@ def _render_images(
11821287
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
11831288
0,
11841289
).sum(0)
1185-
colored = colored[:, :, :3]
1290+
colored = np.clip(colored[:, :, :3], 0, 1)
11861291
elif n_channels == 3:
11871292
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
11881293
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
11891294
colored = np.stack(
11901295
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
11911296
0,
11921297
).sum(0)
1193-
colored = colored[:, :, :3]
1298+
colored = np.clip(colored[:, :, :3], 0, 1)
11941299
else:
11951300
if isinstance(render_params.cmap_params, list):
11961301
cmap_is_default = render_params.cmap_params[0].cmap_is_default
@@ -1241,7 +1346,7 @@ def _render_images(
12411346

12421347
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
12431348
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
1244-
colored = colored[:, :, :3]
1349+
colored = np.clip(colored[:, :, :3], 0, 1)
12451350

12461351
_ax_show_and_transform(
12471352
colored,

src/spatialdata_plot/pl/render_params.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,6 @@ class ImageRenderParams:
265265
channel: list[str] | list[int] | int | str | None = None
266266
palette: ListedColormap | list[str] | None = None
267267
alpha: float = 1.0
268-
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
269268
scale: str | None = None
270269
zorder: int = 0
271270
colorbar: bool | str | None = "auto"

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,7 @@ def _prepare_cmap_norm(
628628

629629
assert isinstance(cmap, Colormap), f"Invalid type of `cmap`: {type(cmap)}, expected `Colormap`."
630630

631-
if norm is None:
632-
norm = Normalize(vmin=None, vmax=None, clip=False)
631+
norm = Normalize(vmin=None, vmax=None, clip=False) if norm is None else copy(norm)
633632

634633
cmap.set_bad(na_color.get_hex_with_alpha())
635634

0 commit comments

Comments
 (0)