Skip to content

Commit c07d2a6

Browse files
committed
Refactor and add tests
1 parent 8121711 commit c07d2a6

2 files changed

Lines changed: 502 additions & 58 deletions

File tree

ultraplot/figure.py

Lines changed: 100 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,20 +1164,14 @@ def _snap_axes_to_pixel_grid(self, renderer) -> None:
11641164
which="both",
11651165
)
11661166

1167-
def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
1167+
def _find_aspect_constrained_spans(self, axes, *, tol=1e-9):
11681168
"""
1169-
Propagate aspect-constrained spanning axes boxes across sibling rows/columns.
1169+
Identify spanning axes whose aspect constraint caused matplotlib to
1170+
shrink them inside their gridspec slot.
11701171
1171-
When a fixed-aspect subplot spans multiple rows or columns, matplotlib shrinks
1172-
just that axes inside its gridspec slot. In layouts like ``[[1, 2], [1, 3]]``
1173-
this leaves the adjacent stack slightly taller or wider than the spanning axes.
1174-
Here we remap the sibling subplot slots onto the aspect-constrained box so the
1175-
overall geometry stays aligned.
1172+
Returns a list of ``(axis, start, stop, slot, pos, ref_ax)`` tuples
1173+
where *axis* is ``'y'`` for row-spanning or ``'x'`` for column-spanning.
11761174
"""
1177-
axes = list(self._iter_axes(hidden=False, children=False, panels=False))
1178-
if not axes:
1179-
return
1180-
11811175
spans = []
11821176
for ax in axes:
11831177
try:
@@ -1189,7 +1183,7 @@ def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
11891183
row1, row2, col1, col2 = ss._get_rows_columns()
11901184
slot = ss.get_position(self)
11911185
pos = ax.get_position(original=False)
1192-
except Exception:
1186+
except (AttributeError, TypeError):
11931187
continue
11941188

11951189
if row2 > row1 and (
@@ -1202,7 +1196,13 @@ def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
12021196
or abs((pos.x0 + pos.width) - (slot.x0 + slot.width)) > tol
12031197
):
12041198
spans.append(("x", col1, col2, slot, pos, ax))
1199+
return spans
12051200

1201+
def _remap_axes_to_span(self, axes, spans, *, tol=1e-9):
1202+
"""
1203+
Remap auto-aspect sibling axes so they align with the
1204+
aspect-constrained bounds described by *spans*.
1205+
"""
12061206
for axis, start, stop, slot, pos, ref_ax in spans:
12071207
slot0 = slot.y0 if axis == "y" else slot.x0
12081208
slotsize = slot.height if axis == "y" else slot.width
@@ -1226,7 +1226,7 @@ def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
12261226
if col1 < start or col2 > stop:
12271227
continue
12281228
old = ss.get_position(self)
1229-
except Exception:
1229+
except (AttributeError, TypeError):
12301230
continue
12311231

12321232
if axis == "y":
@@ -1243,6 +1243,22 @@ def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
12431243
bounds = [new0, old.y0, new1 - new0, old.height]
12441244
ax.set_position(bounds, which="both")
12451245

1246+
def _align_aspect_constrained_axes(self, *, tol: float = 1e-9) -> None:
1247+
"""
1248+
Propagate aspect-constrained spanning axes boxes across sibling rows/columns.
1249+
1250+
When a fixed-aspect subplot spans multiple rows or columns, matplotlib shrinks
1251+
just that axes inside its gridspec slot. In layouts like ``[[1, 2], [1, 3]]``
1252+
this leaves the adjacent stack slightly taller or wider than the spanning axes.
1253+
Here we remap the sibling subplot slots onto the aspect-constrained box so the
1254+
overall geometry stays aligned.
1255+
"""
1256+
axes = list(self._iter_axes(hidden=False, children=False, panels=False))
1257+
if not axes:
1258+
return
1259+
spans = self._find_aspect_constrained_spans(axes, tol=tol)
1260+
self._remap_axes_to_span(axes, spans, tol=tol)
1261+
12461262
def _share_ticklabels(self, *, axis: str) -> None:
12471263
"""
12481264
Tick label sharing is determined at the figure level. While
@@ -2641,6 +2657,59 @@ def _align_super_title(self, renderer):
26412657
y = y_target - y_bbox
26422658
self._suptitle.set_position((x, y))
26432659

2660+
@staticmethod
2661+
def _deduplicate_axes(axes):
2662+
"""
2663+
Resolve panel parents and remove duplicates, preserving order.
2664+
"""
2665+
seen = set()
2666+
unique = []
2667+
for ax in axes:
2668+
ax = ax._panel_parent or ax
2669+
ax_id = id(ax)
2670+
if ax_id not in seen:
2671+
seen.add(ax_id)
2672+
unique.append(ax)
2673+
return unique
2674+
2675+
@staticmethod
2676+
def _normalize_title_alignment(loc):
2677+
"""
2678+
Convert a *loc* string to a horizontal alignment for ``Text.set_ha``.
2679+
"""
2680+
align = _translate_loc(loc, "text")
2681+
match align:
2682+
case "left" | "outer left" | "upper left" | "lower left":
2683+
return "left"
2684+
case "center" | "upper center" | "lower center":
2685+
return "center"
2686+
case "right" | "outer right" | "upper right" | "lower right":
2687+
return "right"
2688+
case _:
2689+
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
2690+
2691+
@staticmethod
2692+
def _resolve_title_props(fontdict, kwargs):
2693+
"""
2694+
Build the property dict for a title from rc defaults, *fontdict*,
2695+
and extra *kwargs*.
2696+
"""
2697+
kw = rc.fill(
2698+
{
2699+
"size": "title.size",
2700+
"weight": "title.weight",
2701+
"color": "title.color",
2702+
"family": "font.family",
2703+
},
2704+
context=True,
2705+
)
2706+
if "color" in kw and kw["color"] == "auto":
2707+
del kw["color"]
2708+
if fontdict:
2709+
kw.update(fontdict)
2710+
kw.update(kwargs)
2711+
return kw
2712+
26442713
def _update_subset_title(
26452714
self,
26462715
axes: Iterable[paxes.Axes],
@@ -2673,47 +2742,17 @@ def _update_subset_title(
26732742
if not axes:
26742743
raise ValueError("Need at least one axes to create a shared subplot title.")
26752744

2676-
seen = set()
2677-
unique_axes = []
2678-
for ax in axes:
2679-
ax = ax._panel_parent or ax
2680-
ax_id = id(ax)
2681-
if ax_id in seen:
2682-
continue
2683-
seen.add(ax_id)
2684-
unique_axes.append(ax)
2685-
axes = unique_axes
2745+
axes = self._deduplicate_axes(axes)
26862746
if len(axes) < 2:
26872747
return axes[0].set_title(
26882748
title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs
26892749
)
26902750

26912751
key = tuple(sorted(id(ax) for ax in axes))
26922752
group = self._subset_title_dict.get(key)
2693-
kw = rc.fill(
2694-
{
2695-
"size": "title.size",
2696-
"weight": "title.weight",
2697-
"color": "title.color",
2698-
"family": "font.family",
2699-
},
2700-
context=True,
2701-
)
2702-
if "color" in kw and kw["color"] == "auto":
2703-
del kw["color"]
2704-
if fontdict:
2705-
kw.update(fontdict)
2706-
kw.update(kwargs)
2707-
align = _translate_loc(loc, "text")
2708-
match align:
2709-
case "left" | "outer left" | "upper left" | "lower left":
2710-
align = "left"
2711-
case "center" | "upper center" | "lower center":
2712-
align = "center"
2713-
case "right" | "outer right" | "upper right" | "lower right":
2714-
align = "right"
2715-
case _:
2716-
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
2753+
kw = self._resolve_title_props(fontdict, kwargs)
2754+
align = self._normalize_title_alignment(loc)
2755+
27172756
if group is None:
27182757
artist = self.text(
27192758
0.5,
@@ -2739,6 +2778,16 @@ def _update_subset_title(
27392778
artist.update(kw)
27402779
return artist
27412780

2781+
def _visible_subset_group_axes(self, group):
2782+
"""
2783+
Return visible axes from a subset-title group that belong to this figure.
2784+
"""
2785+
return [
2786+
ax
2787+
for ax in group["axes"]
2788+
if ax is not None and ax.figure is self and ax.get_visible()
2789+
]
2790+
27422791
def _get_subset_title_bbox(
27432792
self, ax: paxes.Axes, renderer
27442793
) -> mtransforms.Bbox | None:
@@ -2757,15 +2806,12 @@ def _get_subset_title_bbox(
27572806
if not artist.get_visible() or not artist.get_text():
27582807
continue
27592808
axs = [
2760-
group_ax._panel_parent or group_ax
2761-
for group_ax in group["axes"]
2762-
if group_ax is not None
2763-
and group_ax.figure is self
2764-
and group_ax.get_visible()
2809+
a._panel_parent or a
2810+
for a in self._visible_subset_group_axes(group)
27652811
]
27662812
if not axs or ax not in axs:
27672813
continue
2768-
top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs)
2814+
top = min(a._range_subplotspec("y")[0] for a in axs)
27692815
if ax._range_subplotspec("y")[0] == top:
27702816
bboxes.append(artist.get_window_extent(renderer))
27712817
return mtransforms.Bbox.union(bboxes) if bboxes else None
@@ -2777,11 +2823,7 @@ def _align_subset_titles(self, renderer):
27772823
for key in list(self._subset_title_dict):
27782824
group = self._subset_title_dict[key]
27792825
artist = group["artist"]
2780-
axs = [
2781-
ax
2782-
for ax in group["axes"]
2783-
if ax is not None and ax.figure is self and ax.get_visible()
2784-
]
2826+
axs = self._visible_subset_group_axes(group)
27852827
if not axs:
27862828
artist.remove()
27872829
del self._subset_title_dict[key]

0 commit comments

Comments
 (0)