diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index d2a57c86..26afd31e 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -228,8 +228,8 @@ def xp_assert_close( actual: Array, desired: Array, *, - rtol: float | None = None, - atol: float = 0, + rtol: float | Array | None = None, + atol: float | Array = 0, equal_nan: bool = True, err_msg: str = "", verbose: bool = True, @@ -246,9 +246,9 @@ def xp_assert_close( The array produced by the tested function. desired : Array The expected array (typically hardcoded). - rtol : float, optional + rtol : float or Array, optional Relative tolerance. Default: dtype-dependent. - atol : float, optional + atol : float or Array, optional Absolute tolerance. Default: 0. equal_nan : bool, default: True Whether to consider NaNs in corresponding locations as equal. @@ -271,6 +271,8 @@ def xp_assert_close( Notes ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. + + Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`. """ actual, desired, xp = _check_ns_shape_dtype( actual, desired, check_dtype, check_shape, check_scalar @@ -286,13 +288,17 @@ def xp_assert_close( rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 else: rtol = 1e-7 + else: + rtol = float(rtol) + + atol = float(atol) actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) np.testing.assert_allclose( # pyright: ignore[reportCallIssue] actual_np, desired_np, - rtol=rtol, # pyright: ignore[reportArgumentType] + rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] atol=atol, equal_nan=equal_nan, err_msg=err_msg,