Skip to content

Commit 941bbdb

Browse files
committed
ENH: delegate broadcast_shapes
1 parent bc126fa commit 941bbdb

4 files changed

Lines changed: 104 additions & 42 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import (
44
argpartition,
55
atleast_nd,
6+
broadcast_shapes,
67
cov,
78
create_diagonal,
89
expand_dims,
@@ -20,7 +21,6 @@
2021
from ._lib._at import at
2122
from ._lib._funcs import (
2223
apply_where,
23-
broadcast_shapes,
2424
default_dtype,
2525
kron,
2626
nunique,

src/array_api_extra/_delegation.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44
from types import ModuleType
5-
from typing import Literal
5+
from typing import Literal, cast
66

77
from ._lib import _funcs
88
from ._lib._utils._compat import (
@@ -20,6 +20,7 @@
2020

2121
__all__ = [
2222
"atleast_nd",
23+
"broadcast_shapes",
2324
"cov",
2425
"create_diagonal",
2526
"expand_dims",
@@ -81,6 +82,68 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
8182
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
8283

8384

85+
def broadcast_shapes(
86+
*shapes: tuple[float | None, ...], xp: ModuleType | None = None
87+
) -> tuple[int | None, ...]:
88+
"""
89+
Compute the shape of the broadcasted arrays.
90+
91+
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
92+
None and NaN sizes.
93+
94+
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
95+
without needing to worry about the backend potentially deep copying
96+
the arrays.
97+
98+
Parameters
99+
----------
100+
*shapes : tuple[int | None, ...]
101+
Shapes of the arrays to broadcast.
102+
xp : array_namespace, optional
103+
The standard-compatible namespace to use for native delegation.
104+
Default: use the array-agnostic implementation.
105+
106+
Returns
107+
-------
108+
tuple[int | None, ...]
109+
The shape of the broadcasted arrays.
110+
111+
See Also
112+
--------
113+
numpy.broadcast_shapes : Equivalent NumPy function.
114+
array_api.broadcast_arrays : Function to broadcast actual arrays.
115+
116+
Notes
117+
-----
118+
This function accepts the Array API's ``None`` for unknown sizes,
119+
as well as Dask's non-standard ``math.nan``.
120+
Regardless of input, the output always contains ``None`` for unknown sizes.
121+
122+
Examples
123+
--------
124+
>>> import array_api_extra as xpx
125+
>>> xpx.broadcast_shapes((2, 3), (2, 1))
126+
(2, 3)
127+
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
128+
(4, 2, 3)
129+
"""
130+
if (
131+
xp is not None
132+
and all(isinstance(size, int) for shape in shapes for size in shape)
133+
and (
134+
is_numpy_namespace(xp)
135+
or is_cupy_namespace(xp)
136+
or is_dask_namespace(xp)
137+
or is_jax_namespace(xp)
138+
or is_torch_namespace(xp)
139+
)
140+
):
141+
int_shapes = cast(tuple[tuple[int, ...], ...], shapes)
142+
return cast(tuple[int | None, ...], xp.broadcast_shapes(*int_shapes))
143+
144+
return _funcs.broadcast_shapes(*shapes)
145+
146+
84147
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
85148
"""
86149
Estimate a covariance matrix (or a stack of covariance matrices).

src/array_api_extra/_lib/_funcs.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -220,46 +220,10 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
220220

221221
# `float` in signature to accept `math.nan` for Dask.
222222
# `int`s are still accepted as `float` is a superclass of `int` in typing
223-
def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
224-
"""
225-
Compute the shape of the broadcasted arrays.
226-
227-
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
228-
None and NaN sizes.
229-
230-
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
231-
without needing to worry about the backend potentially deep copying
232-
the arrays.
233-
234-
Parameters
235-
----------
236-
*shapes : tuple[int | None, ...]
237-
Shapes of the arrays to broadcast.
238-
239-
Returns
240-
-------
241-
tuple[int | None, ...]
242-
The shape of the broadcasted arrays.
243-
244-
See Also
245-
--------
246-
numpy.broadcast_shapes : Equivalent NumPy function.
247-
array_api.broadcast_arrays : Function to broadcast actual arrays.
248-
249-
Notes
250-
-----
251-
This function accepts the Array API's ``None`` for unknown sizes,
252-
as well as Dask's non-standard ``math.nan``.
253-
Regardless of input, the output always contains ``None`` for unknown sizes.
254-
255-
Examples
256-
--------
257-
>>> import array_api_extra as xpx
258-
>>> xpx.broadcast_shapes((2, 3), (2, 1))
259-
(2, 3)
260-
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
261-
(4, 2, 3)
262-
"""
223+
def broadcast_shapes( # numpydoc ignore=PR01,RT01
224+
*shapes: tuple[float | None, ...],
225+
) -> tuple[int | None, ...]:
226+
"""See docstring in array_api_extra._delegation."""
263227
if not shapes:
264228
return () # Match NumPy output
265229

tests/test_funcs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,41 @@ def test_5D_values(self, xp: ModuleType):
489489

490490

491491
class TestBroadcastShapes:
492+
def test_delegates_known_integer_shapes(self, monkeypatch: pytest.MonkeyPatch):
493+
calls = []
494+
495+
def mock_broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
496+
calls.append(shapes)
497+
return (99,)
498+
499+
monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes)
500+
501+
assert broadcast_shapes((2,), (1,), xp=np) == (99,)
502+
assert calls == [((2,), (1,))]
503+
504+
def test_fallback_for_unknown_sizes(self, monkeypatch: pytest.MonkeyPatch):
505+
def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...]:
506+
msg = "Native delegation should not handle unknown sizes"
507+
raise AssertionError(msg)
508+
509+
monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes)
510+
511+
assert broadcast_shapes((None,), (1,), xp=np) == (None,)
512+
assert broadcast_shapes((math.nan,), (1,), xp=np) == (None,)
513+
514+
def test_fallback_without_xp(self, monkeypatch: pytest.MonkeyPatch):
515+
def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...]:
516+
msg = "Native delegation should not be used without xp"
517+
raise AssertionError(msg)
518+
519+
monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes)
520+
521+
assert broadcast_shapes((2,), (1,)) == (2,)
522+
523+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
524+
def test_xp(self, xp: ModuleType):
525+
assert broadcast_shapes((2, 3), (2, 1), xp=xp) == (2, 3)
526+
492527
@pytest.mark.parametrize(
493528
"args",
494529
[

0 commit comments

Comments
 (0)