Skip to content

Commit 97feb21

Browse files
authored
Fix: axes aspect shifting on pixel snapping after drawn (#680)
* closes #679 by restoring alignment for mixed-span subplot layouts when one of the axes uses a fixed aspect ratio. The regression showed up in arrangements like [[1, 2], [1, 3]], where an equal-aspect axis on the left would shrink inside its gridspec slot but the stacked axes on the right would keep their full vertical extent and visibly stick out above and below it. This change teaches the figure layout pass to propagate the aspect-constrained bounds across the neighboring subplots that share the same span, and adds a regression test for both the legacy and UltraLayout code paths so the layout stays visually consistent going forward.
1 parent a1bb307 commit 97feb21

3 files changed

Lines changed: 589 additions & 48 deletions

File tree

ultraplot/figure.py

Lines changed: 178 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,109 @@ def _snap_axes_to_pixel_grid(self, renderer) -> None:
11641164
which="both",
11651165
)
11661166

1167+
def _find_misaligned_spans(
1168+
self, axes: List[paxes.Axes], *, tol: float = 1e-9
1169+
) -> List[Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes]]:
1170+
"""
1171+
Identify spanning axes whose actual position differs from their
1172+
gridspec slot (e.g. because of an aspect constraint).
1173+
1174+
Returns a list of ``(axis, start, stop, slot, pos, ref_ax)`` tuples
1175+
where *axis* is ``'y'`` for row-spanning or ``'x'`` for column-spanning.
1176+
"""
1177+
spans = []
1178+
for ax in axes:
1179+
try:
1180+
ax.apply_aspect()
1181+
ss = ax.get_subplotspec().get_topmost_subplotspec()
1182+
row1, row2, col1, col2 = ss._get_rows_columns()
1183+
slot = ss.get_position(self)
1184+
pos = ax.get_position(original=False)
1185+
except (AttributeError, TypeError):
1186+
continue
1187+
1188+
if row2 > row1 and (
1189+
abs(pos.y0 - slot.y0) > tol
1190+
or abs((pos.y0 + pos.height) - (slot.y0 + slot.height)) > tol
1191+
):
1192+
spans.append(("y", row1, row2, slot, pos, ax))
1193+
if col2 > col1 and (
1194+
abs(pos.x0 - slot.x0) > tol
1195+
or abs((pos.x0 + pos.width) - (slot.x0 + slot.width)) > tol
1196+
):
1197+
spans.append(("x", col1, col2, slot, pos, ax))
1198+
return spans
1199+
1200+
def _remap_axes_to_span(
1201+
self,
1202+
axes: List[paxes.Axes],
1203+
spans: List[
1204+
Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes]
1205+
],
1206+
*,
1207+
tol: float = 1e-9,
1208+
) -> None:
1209+
"""
1210+
Remap sibling axes so they align with the actual bounds of
1211+
spanning axes described by *spans*. Siblings with their own
1212+
fixed aspect are skipped since they have independent constraints.
1213+
"""
1214+
for axis, start, stop, slot, pos, ref_ax in spans:
1215+
slot0 = slot.y0 if axis == "y" else slot.x0
1216+
slotsize = slot.height if axis == "y" else slot.width
1217+
pos0 = pos.y0 if axis == "y" else pos.x0
1218+
possize = pos.height if axis == "y" else pos.width
1219+
if slotsize <= tol or possize <= tol:
1220+
continue
1221+
1222+
for ax in axes:
1223+
if ax is ref_ax:
1224+
continue
1225+
try:
1226+
if ax.get_aspect() != "auto":
1227+
continue
1228+
ss = ax.get_subplotspec().get_topmost_subplotspec()
1229+
row1, row2, col1, col2 = ss._get_rows_columns()
1230+
if axis == "y":
1231+
if row1 < start or row2 > stop:
1232+
continue
1233+
else:
1234+
if col1 < start or col2 > stop:
1235+
continue
1236+
old = ss.get_position(self)
1237+
except (AttributeError, TypeError):
1238+
continue
1239+
1240+
if axis == "y":
1241+
rel0 = (old.y0 - slot0) / slotsize
1242+
rel1 = (old.y0 + old.height - slot0) / slotsize
1243+
new0 = pos0 + rel0 * possize
1244+
new1 = pos0 + rel1 * possize
1245+
bounds = [old.x0, new0, old.width, new1 - new0]
1246+
else:
1247+
rel0 = (old.x0 - slot0) / slotsize
1248+
rel1 = (old.x0 + old.width - slot0) / slotsize
1249+
new0 = pos0 + rel0 * possize
1250+
new1 = pos0 + rel1 * possize
1251+
bounds = [new0, old.y0, new1 - new0, old.height]
1252+
ax.set_position(bounds, which="both")
1253+
1254+
def _align_spanning_axes(self, *, tol: float = 1e-9) -> None:
1255+
"""
1256+
Align sibling subplots to spanning axes whose actual position
1257+
differs from their gridspec slot.
1258+
1259+
When a subplot spans multiple rows or columns and is shrunk inside
1260+
its slot (e.g. by a fixed aspect ratio), the adjacent subplots keep
1261+
their full extent and visibly stick out. This method detects the
1262+
mismatch and remaps the sibling positions proportionally.
1263+
"""
1264+
axes = list(self._iter_axes(hidden=False, children=False, panels=False))
1265+
if not axes:
1266+
return
1267+
spans = self._find_misaligned_spans(axes, tol=tol)
1268+
self._remap_axes_to_span(axes, spans, tol=tol)
1269+
11671270
def _share_ticklabels(self, *, axis: str) -> None:
11681271
"""
11691272
Tick label sharing is determined at the figure level. While
@@ -2562,6 +2665,61 @@ def _align_super_title(self, renderer):
25622665
y = y_target - y_bbox
25632666
self._suptitle.set_position((x, y))
25642667

2668+
@staticmethod
2669+
def _deduplicate_axes(axes: Iterable[paxes.Axes]) -> List[paxes.Axes]:
2670+
"""
2671+
Resolve panel parents and remove duplicates, preserving order.
2672+
"""
2673+
seen = set()
2674+
unique = []
2675+
for ax in axes:
2676+
ax = ax._panel_parent or ax
2677+
ax_id = id(ax)
2678+
if ax_id not in seen:
2679+
seen.add(ax_id)
2680+
unique.append(ax)
2681+
return unique
2682+
2683+
@staticmethod
2684+
def _normalize_title_alignment(loc: str) -> str:
2685+
"""
2686+
Convert a *loc* string to a horizontal alignment for ``Text.set_ha``.
2687+
"""
2688+
align = _translate_loc(loc, "text")
2689+
match align:
2690+
case "left" | "outer left" | "upper left" | "lower left":
2691+
return "left"
2692+
case "center" | "upper center" | "lower center":
2693+
return "center"
2694+
case "right" | "outer right" | "upper right" | "lower right":
2695+
return "right"
2696+
case _:
2697+
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
2698+
2699+
@staticmethod
2700+
def _resolve_title_props(
2701+
fontdict: dict[str, Any] | None, kwargs: dict[str, Any]
2702+
) -> dict[str, Any]:
2703+
"""
2704+
Build the property dict for a title from rc defaults, *fontdict*,
2705+
and extra *kwargs*.
2706+
"""
2707+
kw = rc.fill(
2708+
{
2709+
"size": "title.size",
2710+
"weight": "title.weight",
2711+
"color": "title.color",
2712+
"family": "font.family",
2713+
},
2714+
context=True,
2715+
)
2716+
if "color" in kw and kw["color"] == "auto":
2717+
del kw["color"]
2718+
if fontdict:
2719+
kw.update(fontdict)
2720+
kw.update(kwargs)
2721+
return kw
2722+
25652723
def _update_subset_title(
25662724
self,
25672725
axes: Iterable[paxes.Axes],
@@ -2594,47 +2752,17 @@ def _update_subset_title(
25942752
if not axes:
25952753
raise ValueError("Need at least one axes to create a shared subplot title.")
25962754

2597-
seen = set()
2598-
unique_axes = []
2599-
for ax in axes:
2600-
ax = ax._panel_parent or ax
2601-
ax_id = id(ax)
2602-
if ax_id in seen:
2603-
continue
2604-
seen.add(ax_id)
2605-
unique_axes.append(ax)
2606-
axes = unique_axes
2755+
axes = self._deduplicate_axes(axes)
26072756
if len(axes) < 2:
26082757
return axes[0].set_title(
26092758
title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs
26102759
)
26112760

26122761
key = tuple(sorted(id(ax) for ax in axes))
26132762
group = self._subset_title_dict.get(key)
2614-
kw = rc.fill(
2615-
{
2616-
"size": "title.size",
2617-
"weight": "title.weight",
2618-
"color": "title.color",
2619-
"family": "font.family",
2620-
},
2621-
context=True,
2622-
)
2623-
if "color" in kw and kw["color"] == "auto":
2624-
del kw["color"]
2625-
if fontdict:
2626-
kw.update(fontdict)
2627-
kw.update(kwargs)
2628-
align = _translate_loc(loc, "text")
2629-
match align:
2630-
case "left" | "outer left" | "upper left" | "lower left":
2631-
align = "left"
2632-
case "center" | "upper center" | "lower center":
2633-
align = "center"
2634-
case "right" | "outer right" | "upper right" | "lower right":
2635-
align = "right"
2636-
case _:
2637-
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
2763+
kw = self._resolve_title_props(fontdict, kwargs)
2764+
align = self._normalize_title_alignment(loc)
2765+
26382766
if group is None:
26392767
artist = self.text(
26402768
0.5,
@@ -2660,6 +2788,16 @@ def _update_subset_title(
26602788
artist.update(kw)
26612789
return artist
26622790

2791+
def _visible_subset_group_axes(self, group: dict[str, Any]) -> List[paxes.Axes]:
2792+
"""
2793+
Return visible axes from a subset-title group that belong to this figure.
2794+
"""
2795+
return [
2796+
ax
2797+
for ax in group["axes"]
2798+
if ax is not None and ax.figure is self and ax.get_visible()
2799+
]
2800+
26632801
def _get_subset_title_bbox(
26642802
self, ax: paxes.Axes, renderer
26652803
) -> mtransforms.Bbox | None:
@@ -2677,32 +2815,22 @@ def _get_subset_title_bbox(
26772815
artist = group["artist"]
26782816
if not artist.get_visible() or not artist.get_text():
26792817
continue
2680-
axs = [
2681-
group_ax._panel_parent or group_ax
2682-
for group_ax in group["axes"]
2683-
if group_ax is not None
2684-
and group_ax.figure is self
2685-
and group_ax.get_visible()
2686-
]
2818+
axs = [a._panel_parent or a for a in self._visible_subset_group_axes(group)]
26872819
if not axs or ax not in axs:
26882820
continue
2689-
top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs)
2821+
top = min(a._range_subplotspec("y")[0] for a in axs)
26902822
if ax._range_subplotspec("y")[0] == top:
26912823
bboxes.append(artist.get_window_extent(renderer))
26922824
return mtransforms.Bbox.union(bboxes) if bboxes else None
26932825

2694-
def _align_subset_titles(self, renderer):
2826+
def _align_subset_titles(self, renderer: Any) -> None:
26952827
"""
26962828
Update the positions of titles spanning subplot subsets.
26972829
"""
26982830
for key in list(self._subset_title_dict):
26992831
group = self._subset_title_dict[key]
27002832
artist = group["artist"]
2701-
axs = [
2702-
ax
2703-
for ax in group["axes"]
2704-
if ax is not None and ax.figure is self and ax.get_visible()
2705-
]
2833+
axs = self._visible_subset_group_axes(group)
27062834
if not axs:
27072835
artist.remove()
27082836
del self._subset_title_dict[key]
@@ -2979,9 +3107,11 @@ def _align_content(): # noqa: E306
29793107
return
29803108
if aspect:
29813109
gs._auto_layout_aspect()
3110+
self._align_spanning_axes()
29823111
_align_content()
29833112
if tight:
29843113
gs._auto_layout_tight(renderer)
3114+
self._align_spanning_axes()
29853115
_align_content()
29863116

29873117
@warnings._rename_kwargs(

0 commit comments

Comments
 (0)