Description
When passing a list of Axes to show(), the user must also pass fig=. This is unnecessary: every Axes already belongs to a Figure, so the figure can always be inferred from the axes via ax[0].get_figure().
Current behavior:
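import matplotlib.pyplot as plt
import spatialdata_plot  # noqa: F401 (registers the .pl accessor)
fig, axes = plt.subplots(1, 2)
sdata.pl.render_shapes().pl.show(ax=list(axes))  # sdata: a SpatialData object with a shapes element
# → ValueError: Invalid value of `fig`: None. If a list of `Axes` is passed, a `Figure` must also be specified.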
Expected behavior:
fig, axes = plt.subplots(1, 2)
sdata.pl.render_shapes().pl.show(ax=list(axes))  # should work: fig is inferred from the axes
For reference, scanpy's plotting functions accept ax without requiring fig.
Location: src/spatialdata_plot/pl/utils.py, _prepare_params_plot(), lines 266-269.