|
2 | 2 |
|
3 | 3 | from collections.abc import Sequence |
4 | 4 | from types import ModuleType |
5 | | -from typing import Literal |
| 5 | +from typing import Literal, cast |
6 | 6 |
|
7 | 7 | from ._lib import _funcs |
8 | 8 | from ._lib._utils._compat import ( |
|
20 | 20 |
|
21 | 21 | __all__ = [ |
22 | 22 | "atleast_nd", |
| 23 | + "broadcast_shapes", |
23 | 24 | "cov", |
24 | 25 | "create_diagonal", |
25 | 26 | "expand_dims", |
@@ -81,6 +82,68 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array |
81 | 82 | return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
82 | 83 |
|
83 | 84 |
|
| 85 | +def broadcast_shapes( |
| 86 | + *shapes: tuple[float | None, ...], xp: ModuleType | None = None |
| 87 | +) -> tuple[int | None, ...]: |
| 88 | + """ |
| 89 | + Compute the shape of the broadcasted arrays. |
| 90 | +
|
| 91 | + Duplicates :func:`numpy.broadcast_shapes`, with additional support for |
| 92 | + None and NaN sizes. |
| 93 | +
|
| 94 | + This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape`` |
| 95 | + without needing to worry about the backend potentially deep copying |
| 96 | + the arrays. |
| 97 | +
|
| 98 | + Parameters |
| 99 | + ---------- |
| 100 | + *shapes : tuple[int | None, ...] |
| 101 | + Shapes of the arrays to broadcast. |
| 102 | + xp : array_namespace, optional |
| 103 | + The standard-compatible namespace to use for native delegation. |
| 104 | + Default: use the array-agnostic implementation. |
| 105 | +
|
| 106 | + Returns |
| 107 | + ------- |
| 108 | + tuple[int | None, ...] |
| 109 | + The shape of the broadcasted arrays. |
| 110 | +
|
| 111 | + See Also |
| 112 | + -------- |
| 113 | + numpy.broadcast_shapes : Equivalent NumPy function. |
| 114 | + array_api.broadcast_arrays : Function to broadcast actual arrays. |
| 115 | +
|
| 116 | + Notes |
| 117 | + ----- |
| 118 | + This function accepts the Array API's ``None`` for unknown sizes, |
| 119 | + as well as Dask's non-standard ``math.nan``. |
| 120 | + Regardless of input, the output always contains ``None`` for unknown sizes. |
| 121 | +
|
| 122 | + Examples |
| 123 | + -------- |
| 124 | + >>> import array_api_extra as xpx |
| 125 | + >>> xpx.broadcast_shapes((2, 3), (2, 1)) |
| 126 | + (2, 3) |
| 127 | + >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3)) |
| 128 | + (4, 2, 3) |
| 129 | + """ |
| 130 | + if ( |
| 131 | + xp is not None |
| 132 | + and all(isinstance(size, int) for shape in shapes for size in shape) |
| 133 | + and ( |
| 134 | + is_numpy_namespace(xp) |
| 135 | + or is_cupy_namespace(xp) |
| 136 | + or is_dask_namespace(xp) |
| 137 | + or is_jax_namespace(xp) |
| 138 | + or is_torch_namespace(xp) |
| 139 | + ) |
| 140 | + ): |
| 141 | + int_shapes = cast(tuple[tuple[int, ...], ...], shapes) |
| 142 | + return cast(tuple[int | None, ...], xp.broadcast_shapes(*int_shapes)) |
| 143 | + |
| 144 | + return _funcs.broadcast_shapes(*shapes) |
| 145 | + |
| 146 | + |
84 | 147 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
85 | 148 | """ |
86 | 149 | Estimate a covariance matrix (or a stack of covariance matrices). |
|
0 commit comments