Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions tests/test_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,142 @@ def test_secondary_not_modified(self) -> None:

# Secondary traces should still use original yaxis
assert secondary.data[0].yaxis == original_yaxis


class TestLegendVisibility:
"""Tests that combined figures preserve legend visibility."""

def test_overlay_single_trace_figures_with_names(self) -> None:
"""Overlay of named single-trace figures shows legend."""
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="a")
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="b")

fig1 = xpx(da1).line()
fig1.update_traces(name="Series A")
fig2 = xpx(da2).line()
fig2.update_traces(name="Series B")

combined = overlay(fig1, fig2)

assert combined.data[0].showlegend is True
assert combined.data[1].showlegend is True

def test_overlay_unnamed_traces_get_yaxis_title(self) -> None:
"""Overlay of unnamed traces derives names from y-axis titles."""
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure")

fig1 = xpx(da1).line()
fig2 = xpx(da2).line()

combined = overlay(fig1, fig2)

# Names derived from y-axis titles (DataArray names)
assert combined.data[0].name == "Temperature"
assert combined.data[1].name == "Pressure"
assert combined.data[0].showlegend is True
assert combined.data[1].showlegend is True

def test_overlay_same_name_disambiguated(self) -> None:
"""Overlay of figures with same y-axis title gets numeric suffix."""
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="value")
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="value")

fig1 = xpx(da1).line()
fig2 = xpx(da2).line()

combined = overlay(fig1, fig2)

assert combined.data[0].name == "value (1)"
assert combined.data[1].name == "value (2)"

def test_overlay_multi_trace_deduplicates_legend(self) -> None:
"""Overlay of multi-trace figures deduplicates shared legendgroups."""
da = xr.DataArray(
np.random.rand(10, 3),
dims=["x", "cat"],
coords={"cat": ["A", "B", "C"]},
)
fig1 = xpx(da).area()
fig2 = xpx(da).line()

combined = overlay(fig1, fig2)

# First occurrence of each legendgroup should show, duplicates hidden
from collections import defaultdict

groups: dict[str, list[bool]] = defaultdict(list)
for trace in combined.data:
lg = trace.legendgroup
groups[lg].append(trace.showlegend is True)

for lg, flags in groups.items():
assert flags.count(True) == 1, f"legendgroup {lg!r} has {flags.count(True)} visible"

def test_add_secondary_y_single_trace_with_names(self) -> None:
"""add_secondary_y of named single-trace figures shows legend."""
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="temp")
da2 = xr.DataArray([100, 200, 300], dims=["x"], name="precip")

fig1 = xpx(da1).line()
fig1.update_traces(name="Temperature")
fig2 = xpx(da2).bar()
fig2.update_traces(name="Precipitation")

combined = add_secondary_y(fig1, fig2)

assert combined.data[0].showlegend is True
assert combined.data[1].showlegend is True

def test_overlay_faceted_legendgroup_dedup(self) -> None:
"""Faceted overlay keeps only one showlegend=True per legendgroup."""
da = xr.DataArray(
np.random.rand(10, 2, 2),
dims=["x", "cat", "facet"],
coords={"cat": ["A", "B"], "facet": ["left", "right"]},
)
fig1 = xpx(da).area(facet_col="facet")
fig2 = xpx(da).line(facet_col="facet")

combined = overlay(fig1, fig2)

# Check each legendgroup has at least one showlegend=True
from collections import defaultdict

groups: dict[str, list[bool]] = defaultdict(list)
for trace in combined.data:
lg = trace.legendgroup or ""
if lg:
groups[lg].append(trace.showlegend is True)

for lg, flags in groups.items():
assert any(flags), f"legendgroup {lg!r} has no showlegend=True trace"

def test_overlay_animation_frames_preserve_style(self) -> None:
"""Animation frame traces keep legend and color from fig.data."""
da = xr.DataArray(
np.random.rand(10, 3),
dims=["x", "time"],
coords={"time": [0, 1, 2]},
name="Population",
)
da_smooth = da.rolling(x=3, center=True).mean()
da_smooth.name = "Smoothed"

fig1 = xpx(da).bar(animation_frame="time")
fig1.update_traces(marker={"color": "steelblue"})
fig2 = xpx(da_smooth).line(animation_frame="time")
fig2.update_traces(line={"color": "red"})

combined = overlay(fig1, fig2)

for frame in combined.frames:
for i, ft in enumerate(frame.data):
src = combined.data[i]
assert ft.name == src.name
assert ft.showlegend == src.showlegend
assert ft.legendgroup == src.legendgroup
# Bar trace should keep steelblue
assert frame.data[0].marker.color == "steelblue"
# Line trace should keep red
assert frame.data[1].line.color == "red"
212 changes: 193 additions & 19 deletions xarray_plotly/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,171 @@
import plotly.graph_objects as go


def _get_yaxis_title(fig: go.Figure) -> str:
"""Extract the primary y-axis title text from a figure.

Args:
fig: A Plotly figure.

Returns:
The y-axis title text, or empty string if not set.
"""
try:
return fig.layout.yaxis.title.text or ""
except AttributeError:
return ""


def _ensure_legend_visibility(
combined: go.Figure,
source_figs: list[go.Figure],
trace_slices: list[slice],
) -> None:
"""Fix legend visibility on a combined figure.

Handles three problems that arise when combining Plotly Express figures:

1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color)
figures. We derive a name from each source figure's y-axis title.
2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace
figures. We ensure at least one trace per ``legendgroup`` (or each
ungrouped named trace) has ``showlegend=True``.
3. **Duplicate legend entries** — when two source figures share the same
``legendgroup`` names, we deduplicate so only the first trace per
group shows in the legend.

Args:
combined: The combined Plotly figure (mutated in place).
source_figs: The original source figures, in trace order.
trace_slices: Slices into ``combined.data`` for each source figure.
"""
from collections import defaultdict

# --- Step 1: label unnamed traces from source y-axis titles -----------
labels = [_get_yaxis_title(f) for f in source_figs]

# If all labels are the same, disambiguate
unique_labels = {lb for lb in labels if lb}
if len(unique_labels) == 1:
labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))]

Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
for label, sl in zip(labels, trace_slices, strict=False):
if not label:
continue
for trace in combined.data[sl]:
if not getattr(trace, "name", None):
trace.name = label
trace.legendgroup = label

# --- Step 2 & 3: fix showlegend per legendgroup -----------------------
grouped: dict[str, list[Any]] = defaultdict(list)
ungrouped: list[Any] = []

for trace in combined.data:
lg = getattr(trace, "legendgroup", None) or ""
if lg:
grouped[lg].append(trace)
else:
ungrouped.append(trace)

for traces in grouped.values():
has_visible = False
for t in traces:
if has_visible:
# Deduplicate: only first keeps showlegend
t.showlegend = False
elif getattr(t, "name", None):
t.showlegend = True
has_visible = True

# Ungrouped traces with a name should show in the legend
for trace in ungrouped:
if getattr(trace, "name", None):
trace.showlegend = True

# --- Step 4: propagate style properties to animation frame traces ------
# When Plotly animates, frame trace data overwrites fig.data properties.
# PX frame traces carry name="", showlegend=False and default colors,
# discarding any styling the user applied via update_traces() before
# combining. Propagate display properties from fig.data into every frame.
_STYLE_ATTRS = ("name", "legendgroup", "showlegend", "marker", "line", "opacity")
for frame in combined.frames or []:
for i, frame_trace in enumerate(frame.data):
if i < len(combined.data):
src = combined.data[i]
for attr in _STYLE_ATTRS:
src_val = getattr(src, attr, None)
if src_val is not None:
setattr(frame_trace, attr, src_val)


def _fix_animation_axis_ranges(fig: go.Figure) -> None:
"""Set axis ranges to encompass data across all animation frames.

Plotly.js computes autorange from ``fig.data`` only and does not
recalculate during animation. When different frames have very different
data ranges (e.g. population of Brazil vs China), values can go off-screen.
This function computes the global min/max for each axis across all frames
and sets explicit ranges on the layout.

Only numeric axes are handled; categorical/date axes are left to autorange.

Args:
fig: A Plotly figure with animation frames (mutated in place).
"""
import numpy as np

if not fig.frames:
return

from collections import defaultdict

# Collect numeric y-values per axis across all traces (fig.data + frames)
y_by_axis: dict[str, list[float]] = defaultdict(list)
x_by_axis: dict[str, list[float]] = defaultdict(list)

for trace in _iter_all_traces(fig):
yaxis = getattr(trace, "yaxis", None) or "y"
xaxis = getattr(trace, "xaxis", None) or "x"

y = getattr(trace, "y", None)
if y is not None:
try:
arr = np.asarray(y, dtype=float)
finite = arr[np.isfinite(arr)]
if len(finite):
y_by_axis[yaxis].extend(finite.tolist())
except (ValueError, TypeError):
pass # Non-numeric (categorical) — skip

x = getattr(trace, "x", None)
if x is not None:
try:
arr = np.asarray(x, dtype=float)
finite = arr[np.isfinite(arr)]
if len(finite):
x_by_axis[xaxis].extend(finite.tolist())
except (ValueError, TypeError):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
pass

# Apply ranges to layout
for axis_ref, values in y_by_axis.items():
if not values:
continue
lo, hi = min(values), max(values)
pad = (hi - lo) * 0.05 or 1 # 5% padding
layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}"
fig.layout[layout_prop].range = [lo - pad, hi + pad]

for axis_ref, values in x_by_axis.items():
if not values:
continue
lo, hi = min(values), max(values)
pad = (hi - lo) * 0.05 or 1
layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}"
fig.layout[layout_prop].range = [lo - pad, hi + pad]
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _iter_all_traces(fig: go.Figure) -> Iterator[Any]:
"""Iterate over all traces in a figure, including animation frames.

Expand Down Expand Up @@ -194,17 +359,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
_validate_compatible_structure(base, overlay)
_validate_animation_compatibility(base, overlay)

# Create new figure with base's layout
combined = go.Figure(layout=copy.deepcopy(base.layout))

# Add all traces from base
for trace in base.data:
combined.add_trace(copy.deepcopy(trace))

# Add all traces from overlays
# Create new figure with base's layout and all traces
all_traces = [copy.deepcopy(t) for t in base.data]
for overlay in overlays:
for trace in overlay.data:
combined.add_trace(copy.deepcopy(trace))
all_traces.extend(copy.deepcopy(t) for t in overlay.data)
combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))

# Handle animation frames
if base.frames:
Expand All @@ -213,6 +372,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts)
combined.frames = merged_frames

# Build trace slices for legend fix
source_figs = [base, *overlays]
slices: list[slice] = []
offset = 0
for fig in source_figs:
n = len(fig.data)
slices.append(slice(offset, offset + n))
offset += n

_ensure_legend_visibility(combined, source_figs, slices)
_fix_animation_axis_ranges(combined)
return combined


Expand Down Expand Up @@ -315,19 +485,15 @@ def add_secondary_y(
rightmost_x = max(x_for_y.values(), key=lambda x: int(x[1:]) if x != "x" else 1)
rightmost_primary_y = next(y for y, x in x_for_y.items() if x == rightmost_x)

# Create new figure with base's layout
combined = go.Figure(layout=copy.deepcopy(base.layout))

# Add all traces from base (primary y-axis)
for trace in base.data:
combined.add_trace(copy.deepcopy(trace))

# Add all traces from secondary, remapped to secondary y-axes
# Build all traces: base (primary) + secondary (remapped to secondary y-axes)
all_traces = [copy.deepcopy(t) for t in base.data]
for trace in secondary.data:
trace_copy = copy.deepcopy(trace)
original_yaxis = getattr(trace_copy, "yaxis", None) or "y"
trace_copy.yaxis = y_mapping[original_yaxis]
combined.add_trace(trace_copy)
all_traces.append(trace_copy)

combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))

# Get the rightmost secondary y-axis name for linking
rightmost_secondary_y = y_mapping[rightmost_primary_y]
Expand Down Expand Up @@ -368,6 +534,14 @@ def add_secondary_y(
merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping)
combined.frames = merged_frames

base_n = len(base.data)
sec_n = len(secondary.data)
_ensure_legend_visibility(
combined,
[base, secondary],
[slice(0, base_n), slice(base_n, base_n + sec_n)],
)
_fix_animation_axis_ranges(combined)
return combined


Expand Down
Loading