forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_testing.py
More file actions
335 lines (291 loc) · 11.3 KB
/
_testing.py
File metadata and controls
335 lines (291 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""
Testing utilities.
Note that this is private API; don't expect it to be stable.
See also ..testing for public testing utilities.
"""
from __future__ import annotations
import math
from types import ModuleType
from typing import Any, cast
import numpy as np
import pytest
from ._utils._compat import (
array_namespace,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_array,
is_torch_namespace,
to_device,
)
from ._utils._typing import Array, Device
# Names exported by a star-import of this module. `xfail` (defined below) is
# not listed here and must be imported explicitly.
__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]
def _check_ns_shape_dtype(
    actual: Array,
    desired: Array,
    check_dtype: bool,
    check_shape: bool,
    check_scalar: bool,
) -> tuple[Array, Array, ModuleType]:  # numpydoc ignore=RT03
    """
    Assert that namespace, shape and dtype of the two arrays match.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    check_dtype, check_shape : bool
        Whether to check agreement between actual and desired dtypes and shapes.
        When `check_shape` is False and `desired` is not 0-dimensional, only the
        total number of elements is compared.
    check_scalar : bool
        NumPy only: whether to check agreement between actual and desired types -
        0d array vs scalar.

    Returns
    -------
    Actual array, desired array broadcast to the shape of `actual`, and their
    namespace.
    """
    actual_xp = array_namespace(actual)  # Raises on Python scalars and lists
    desired_xp = array_namespace(desired)

    msg = f"namespaces do not match: {actual_xp} != {desired_xp}"
    assert actual_xp == desired_xp, msg

    if is_numpy_namespace(actual_xp) and check_scalar:
        # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
        _msg = (
            "array-ness does not match:\n Actual: "
            f"{type(actual)}\n Desired: {type(desired)}"
        )
        assert np.isscalar(actual) == np.isscalar(desired), _msg

    # Dask uses nan instead of None for unknown shapes
    actual_shape = cast(tuple[float, ...], actual.shape)
    desired_shape = cast(tuple[float, ...], desired.shape)
    assert None not in actual_shape  # Requires explicit support
    assert None not in desired_shape
    if is_dask_namespace(desired_xp):
        # Resolve unknown (nan) chunk sizes so the shapes can be compared.
        # The shape is re-read afterwards because compute_chunk_sizes()
        # updates the array in place.
        if any(math.isnan(i) for i in actual_shape):
            actual.compute_chunk_sizes()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
            actual_shape = cast(tuple[float, ...], actual.shape)
        if any(math.isnan(i) for i in desired_shape):
            desired.compute_chunk_sizes()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
            desired_shape = cast(tuple[float, ...], desired.shape)

    if check_shape:
        msg = f"shapes do not match: {actual_shape} != {desired_shape}"
        assert actual_shape == desired_shape, msg
    elif desired.ndim > 0:
        # Ignore shape, but check flattened size. This is normally done by
        # np.testing.assert_array_equal etc even when strict=False, but not for
        # non-materializable arrays.
        # This check excludes 0d arrays as they are special-cased in NumPy.
        actual_size = math.prod(actual_shape)
        desired_size = math.prod(desired_shape)
        msg = f"sizes do not match: {actual_size} != {desired_size}"
        assert actual_size == desired_size, msg

    if check_dtype:
        msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
        assert actual.dtype == desired.dtype, msg

    # Align `desired` with `actual` so the later element-wise comparison sees
    # matching shapes even when check_shape=False.
    desired = desired_xp.broadcast_to(desired, actual_shape)
    return actual, desired, desired_xp
def _is_materializable(x: Array) -> bool:
    """
    Tell whether `as_numpy_array(x)` may be called on `x`.

    Only PyTorch tensors on the ``"meta"`` device are reported as
    non-materializable; any other array yields True.
    """
    # Important: this assumes we are not tracing, i.e. we are neither inside
    # `jax.jit` nor inside `cupy.cuda.Stream.begin_capture`.
    if is_torch_array(x):
        return x.device.type != "meta"  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
    return True
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
    """
    Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.

    Parameters
    ----------
    array : Array
        The array to convert.
    xp : array_namespace
        The namespace of `array`. It selects the backend-specific conversion path.

    Returns
    -------
    numpy.ndarray
        A NumPy array with the contents of `array`.
    """
    if is_cupy_namespace(xp):
        # Explicit device-to-host copy.
        return xp.asnumpy(array)
    if is_pydata_sparse_namespace(xp):
        # Explicit densification.
        return array.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]

    if is_torch_namespace(xp):
        # resolve_conj() returns a new tensor and does not modify `array` in
        # place, so keep its result. Tensors that still have the lazy
        # conjugate bit set cannot be converted to NumPy.
        array = array.resolve_conj()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
        array = to_device(array, "cpu")
    if is_array_api_strict_namespace(xp):
        cpu: Device = xp.Device("CPU_DEVICE")
        array = to_device(array, cpu)
    if is_jax_namespace(xp):
        import jax

        # Note: only needed if the transfer guard is enabled
        cpu = cast(Device, jax.devices("cpu")[0])
        array = to_device(array, cpu)

    return np.asarray(array)
def xp_assert_equal(
    actual: Array,
    desired: Array,
    *,
    err_msg: str = "",
    verbose: bool = True,
    check_dtype: bool = True,
    check_shape: bool = True,
    check_scalar: bool = False,
) -> None:
    """
    Array-API compatible version of `np.testing.assert_array_equal`.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    err_msg : str, optional
        Error message to display on failure.
    verbose : bool, default: True
        Whether to include the conflicting arrays in the error message on failure.
    check_dtype, check_shape : bool, default: True
        Whether to check agreement between actual and desired dtypes and shapes
    check_scalar : bool, default: False
        NumPy only: whether to check agreement between actual and desired types -
        0d array vs scalar.

    See Also
    --------
    xp_assert_close : Similar function for inexact equality checks.
    numpy.testing.assert_array_equal : Similar function for NumPy arrays.
    """
    actual, desired, xp = _check_ns_shape_dtype(
        actual, desired, check_dtype, check_shape, check_scalar
    )
    # Non-materializable arrays (e.g. torch "meta" tensors) only receive the
    # namespace/shape/dtype checks above; their values cannot be compared.
    if _is_materializable(actual):
        np.testing.assert_array_equal(
            as_numpy_array(actual, xp=xp),
            as_numpy_array(desired, xp=xp),
            err_msg=err_msg,
            verbose=verbose,
        )
def xp_assert_less(
    x: Array,
    y: Array,
    *,
    err_msg: str = "",
    verbose: bool = True,
    check_dtype: bool = True,
    check_shape: bool = True,
    check_scalar: bool = False,
) -> None:
    """
    Array-API compatible version of `np.testing.assert_array_less`.

    Parameters
    ----------
    x, y : Array
        The arrays to compare according to ``x < y`` (elementwise).
    err_msg : str, optional
        Error message to display on failure.
    verbose : bool, default: True
        Whether to include the conflicting arrays in the error message on failure.
    check_dtype, check_shape : bool, default: True
        Whether to check agreement between `x` and `y` dtypes and shapes.
    check_scalar : bool, default: False
        NumPy only: whether to check agreement between `x` and `y` types -
        0d array vs scalar.

    See Also
    --------
    xp_assert_close : Similar function for inexact equality checks.
    numpy.testing.assert_array_less : Similar function for NumPy arrays.
    """
    x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
    # Non-materializable arrays (e.g. torch "meta" tensors) only receive the
    # namespace/shape/dtype checks above; their values cannot be compared.
    if not _is_materializable(x):
        return
    x_np = as_numpy_array(x, xp=xp)
    y_np = as_numpy_array(y, xp=xp)
    np.testing.assert_array_less(x_np, y_np, err_msg=err_msg, verbose=verbose)
def xp_assert_close(
    actual: Array,
    desired: Array,
    *,
    rtol: float | Array | None = None,
    atol: float | Array = 0,
    equal_nan: bool = True,
    err_msg: str = "",
    verbose: bool = True,
    check_dtype: bool = True,
    check_shape: bool = True,
    check_scalar: bool = False,
) -> None:
    """
    Array-API compatible version of `np.testing.assert_allclose`.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    rtol : float or Array, optional
        Relative tolerance. Default: dtype-dependent.
    atol : float or Array, optional
        Absolute tolerance. Default: 0.
    equal_nan : bool, default: True
        Whether to consider NaNs in corresponding locations as equal.
    err_msg : str, optional
        Error message to display on failure.
    verbose : bool, default: True
        Whether to include the conflicting arrays in the error message on failure.
    check_dtype, check_shape : bool, default: True
        Whether to check agreement between actual and desired dtypes and shapes
    check_scalar : bool, default: False
        NumPy only: whether to check agreement between actual and desired types -
        0d array vs scalar.

    See Also
    --------
    xp_assert_equal : Similar function for exact equality checks.
    isclose : Public function for checking closeness.
    numpy.testing.assert_allclose : Similar function for NumPy arrays.

    Notes
    -----
    The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
    Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
    When `rtol` is omitted, floating-point (real or complex) `actual` uses
    ``4 * sqrt(eps)`` of its dtype; any other dtype uses ``1e-7``.
    """
    actual, desired, xp = _check_ns_shape_dtype(
        actual, desired, check_dtype, check_shape, check_scalar
    )
    # Non-materializable arrays (e.g. torch "meta" tensors) only receive the
    # namespace/shape/dtype checks above; their values cannot be compared.
    if not _is_materializable(actual):
        return

    if rtol is not None:
        rtol = float(rtol)
    elif xp.isdtype(actual.dtype, ("real floating", "complex floating")):
        # For `np.float64`, a multiplier of 4 places the default `rtol` roughly
        # half way between sqrt(eps) and 1e-7, the default of
        # `numpy.testing.assert_allclose`.
        rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
    else:
        rtol = 1e-7
    atol = float(atol)

    np.testing.assert_allclose(  # pyright: ignore[reportCallIssue]
        as_numpy_array(actual, xp=xp),
        as_numpy_array(desired, xp=xp),
        rtol=rtol,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
        atol=atol,
        equal_nan=equal_nan,
        err_msg=err_msg,
        verbose=verbose,
    )
def xfail(
    request: pytest.FixtureRequest, *, reason: str, strict: bool | None = None
) -> None:
    """
    XFAIL the currently running test.

    Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately
    halting it, so that it may result in a XPASS.
    xref https://github.com/pandas-dev/pandas/issues/38902

    Parameters
    ----------
    request : pytest.FixtureRequest
        ``request`` argument of the test function.
    reason : str
        Reason for the expected failure.
    strict : bool, optional
        If True, the test will be marked as failed if it passes.
        If False, the test will be marked as passed if it fails.
        Default: ``xfail_strict`` value in ``pyproject.toml``, or False if absent.
    """
    # Only forward `strict` when given explicitly, so that pytest falls back
    # to its configured `xfail_strict` setting otherwise.
    marker_kwargs: dict[str, Any] = {"reason": reason}
    if strict is not None:
        marker_kwargs["strict"] = strict
    request.node.add_marker(pytest.mark.xfail(**marker_kwargs))