Skip to content

Commit bcba9e9

Browse files
committed
ENH: Add angle function
1 parent ba2ea5e commit bcba9e9

5 files changed

Lines changed: 114 additions & 1 deletion

File tree

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
angle
910
apply_where
1011
argpartition
1112
at

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import (
4+
angle,
45
argpartition,
56
atleast_nd,
67
cov,
@@ -32,6 +33,7 @@
3233
# pylint: disable=duplicate-code
3334
__all__ = [
3435
"__version__",
36+
"angle",
3537
"apply_where",
3638
"argpartition",
3739
"at",

src/array_api_extra/_delegation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
22+
"angle",
2223
"atleast_nd",
2324
"cov",
2425
"create_diagonal",
@@ -32,6 +33,42 @@
3233
]
3334

3435

36+
def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Array:
37+
"""
38+
Return the angle of a complex argument.
39+
40+
Parameters
41+
----------
42+
z : array
43+
A complex-valued or real-valued array.
44+
deg : bool, optional
45+
Return angle in degrees if True, radians if False (default).
46+
xp : array_namespace, optional
47+
The standard-compatible namespace for `z`. Default: infer.
48+
49+
Returns
50+
-------
51+
array
52+
The counterclockwise angle from the positive real axis on the complex
53+
plane in the range ``(-pi, pi]``.
54+
55+
Notes
56+
-----
57+
This function passes the imaginary and real parts of the argument to
58+
``xp.atan2`` to compute the result.
59+
60+
Examples
61+
--------
62+
>>> import array_api_strict as xp
63+
>>> import array_api_extra as xpx
64+
>>> xpx.angle(xp.asarray([1.0, 1.0j, 1 + 1j]), xp=xp)
65+
Array([0. , 1.57079633, 0.78539816], dtype=array_api_strict.float64)
66+
"""
67+
if xp is None:
68+
xp = array_namespace(z)
69+
return _funcs.angle(z, deg=deg, xp=xp)
70+
71+
3572
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
3673
"""
3774
Recursively expand the dimension of an array to at least `ndim`.

src/array_api_extra/_lib/_funcs.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ._utils._typing import Array, Device, DType
2424

2525
__all__ = [
26+
"angle",
2627
"apply_where",
2728
"atleast_nd",
2829
"broadcast_shapes",
@@ -38,6 +39,24 @@
3839
]
3940

4041

42+
def angle(z: Array, /, *, deg: bool = False, xp: ModuleType) -> Array:
43+
# numpydoc ignore=PR01,RT01
44+
"""See docstring in array_api_extra._delegation."""
45+
if xp.isdtype(z.dtype, "complex floating"):
46+
zimag = xp.imag(z)
47+
zreal = xp.real(z)
48+
else:
49+
if not xp.isdtype(z.dtype, "real floating"):
50+
z = xp.astype(z, default_dtype(xp, device=_compat.device(z)))
51+
zimag = xp.zeros_like(z)
52+
zreal = z
53+
54+
a = xp.atan2(zimag, zreal)
55+
if deg:
56+
a *= 180 / xp.pi
57+
return a
58+
59+
4160
@overload
4261
def apply_where( # numpydoc ignore=GL08
4362
cond: Array,
@@ -50,7 +69,6 @@ def apply_where( # numpydoc ignore=GL08
5069
xp: ModuleType | None = None,
5170
) -> Array: ...
5271

53-
5472
@overload
5573
def apply_where( # numpydoc ignore=GL08
5674
cond: Array,

tests/test_funcs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing_extensions import override
1414

1515
from array_api_extra import (
16+
angle,
1617
apply_where,
1718
argpartition,
1819
at,
@@ -49,6 +50,7 @@
4950
from array_api_extra._lib._utils._typing import Array, Device
5051
from array_api_extra.testing import lazy_xp_function
5152

53+
lazy_xp_function(angle)
5254
lazy_xp_function(apply_where)
5355
lazy_xp_function(argpartition)
5456
lazy_xp_function(atleast_nd)
@@ -73,6 +75,59 @@
7375
lazy_xp_function(_funcs_searchsorted)
7476

7577

78+
class TestAngle:
79+
def test_basic(self, xp: ModuleType):
80+
x = xp.asarray(
81+
[
82+
1 + 3j,
83+
np.sqrt(2) / 2.0 + 1j * np.sqrt(2) / 2,
84+
1,
85+
1j,
86+
-1,
87+
-1j,
88+
1 - 3j,
89+
-1 + 3j,
90+
],
91+
dtype=xp.complex128,
92+
)
93+
expected = xp.asarray(
94+
[
95+
np.arctan(3.0 / 1.0),
96+
np.arctan(1.0),
97+
0,
98+
np.pi / 2,
99+
np.pi,
100+
-np.pi / 2.0,
101+
-np.arctan(3.0 / 1.0),
102+
np.pi - np.arctan(3.0 / 1.0),
103+
],
104+
dtype=xp.float64,
105+
)
106+
107+
xp_assert_close(angle(x), expected, rtol=0, atol=1e-11)
108+
xp_assert_close(angle(x, deg=True), expected * 180 / xp.pi, rtol=0, atol=1e-11)
109+
110+
def test_real(self, xp: ModuleType):
111+
x = xp.asarray([0.0, -0.0, 1.0, -1.0])
112+
expected = xp.asarray([0.0, xp.pi, 0.0, xp.pi], dtype=x.dtype)
113+
xp_assert_close(angle(x), expected)
114+
115+
def test_integral(self, xp: ModuleType):
116+
x = xp.asarray([0, -1, 1], dtype=xp.int32)
117+
actual = angle(x)
118+
expected = xp.asarray(
119+
[0.0, xp.pi, 0.0], dtype=default_dtype(xp, device=get_device(x))
120+
)
121+
xp_assert_close(actual, expected)
122+
assert actual.dtype == expected.dtype
123+
124+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
125+
def test_xp(self, xp: ModuleType):
126+
x = xp.asarray([1.0, 1.0j, 1 + 1j], dtype=xp.complex128)
127+
expected = xp.asarray([0.0, xp.pi / 2, xp.pi / 4], dtype=xp.float64)
128+
xp_assert_close(angle(x, xp=xp), expected)
129+
130+
76131
class TestApplyWhere:
77132
@staticmethod
78133
def f1(x: Array, y: Array | int = 10) -> Array:

0 commit comments

Comments
 (0)