|
21 | 21 | __all__ = [ |
22 | 22 | "atleast_nd", |
23 | 23 | "cov", |
| 24 | + "create_diagonal", |
24 | 25 | "expand_dims", |
25 | 26 | "isclose", |
26 | 27 | "nan_to_num", |
@@ -174,6 +175,67 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
174 | 175 | return _funcs.cov(m, xp=xp) |
175 | 176 |
|
176 | 177 |
|
| 178 | +def create_diagonal( |
| 179 | + x: Array, /, *, offset: int = 0, xp: ModuleType | None = None |
| 180 | +) -> Array: |
| 181 | + """ |
| 182 | + Construct a diagonal array. |
| 183 | +
|
| 184 | + Parameters |
| 185 | + ---------- |
| 186 | + x : array |
| 187 | + An array having shape ``(*batch_dims, k)``. |
| 188 | + offset : int, optional |
| 189 | + Offset from the leading diagonal (default is ``0``). |
| 190 | + Use positive ints for diagonals above the leading diagonal, |
| 191 | + and negative ints for diagonals below the leading diagonal. |
| 192 | + xp : array_namespace, optional |
| 193 | + The standard-compatible namespace for `x`. Default: infer. |
| 194 | +
|
| 195 | + Returns |
| 196 | + ------- |
| 197 | + array |
| 198 | + An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x` |
| 199 | + on the diagonal (offset by `offset`). |
| 200 | +
|
| 201 | + Examples |
| 202 | + -------- |
| 203 | + >>> import array_api_strict as xp |
| 204 | + >>> import array_api_extra as xpx |
| 205 | + >>> x = xp.asarray([2, 4, 8]) |
| 206 | +
|
| 207 | + >>> xpx.create_diagonal(x, xp=xp) |
| 208 | + Array([[2, 0, 0], |
| 209 | + [0, 4, 0], |
| 210 | + [0, 0, 8]], dtype=array_api_strict.int64) |
| 211 | +
|
| 212 | + >>> xpx.create_diagonal(x, offset=-2, xp=xp) |
| 213 | + Array([[0, 0, 0, 0, 0], |
| 214 | + [0, 0, 0, 0, 0], |
| 215 | + [2, 0, 0, 0, 0], |
| 216 | + [0, 4, 0, 0, 0], |
| 217 | + [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) |
| 218 | + """ |
| 219 | + if xp is None: |
| 220 | + xp = array_namespace(x) |
| 221 | + |
| 222 | + if x.ndim == 0: |
| 223 | + err_msg = "`x` must be at least 1-dimensional." |
| 224 | + raise ValueError(err_msg) |
| 225 | + |
| 226 | + if is_torch_namespace(xp): |
| 227 | + return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1) |
| 228 | + |
| 229 | + if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2: |
| 230 | + return xp.diag(x, k=offset) |
| 231 | + |
| 232 | + if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3: |
| 233 | + batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset) |
| 234 | + return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n)) |
| 235 | + |
| 236 | + return _funcs.create_diagonal(x, offset=offset, xp=xp) |
| 237 | + |
| 238 | + |
177 | 239 | def expand_dims( |
178 | 240 | a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
179 | 241 | ) -> Array: |
|
0 commit comments