From 4a65b5d11c6e0fcf9c59da2ea5110191b97aebf9 Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Sun, 17 May 2026 02:50:54 +0530 Subject: [PATCH 1/6] Adding case for non float rtol and atol inputs Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index d2a57c86..f15481f8 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -228,8 +228,8 @@ def xp_assert_close( actual: Array, desired: Array, *, - rtol: float | None = None, - atol: float = 0, + rtol: float | Array | None = None, + atol: float | Array = 0, equal_nan: bool = True, err_msg: str = "", verbose: bool = True, @@ -246,9 +246,9 @@ def xp_assert_close( The array produced by the tested function. desired : Array The expected array (typically hardcoded). - rtol : float, optional + rtol : float or Array, optional Relative tolerance. Default: dtype-dependent. - atol : float, optional + atol : float or Array, optional Absolute tolerance. Default: 0. equal_nan : bool, default: True Whether to consider NaNs in corresponding locations as equal. @@ -287,6 +287,18 @@ def xp_assert_close( else: rtol = 1e-7 + if not isinstance(atol, float): + atol = as_numpy_array(atol, xp=xp) + if atol.ndim > 0: + msg = "atol must be a scalar or 0-D array" + raise TypeError(msg) + + if not isinstance(rtol, float): + rtol = as_numpy_array(rtol, xp=xp) + if rtol.ndim > 0: + msg = "rtol must be a scalar or 0-D array" + raise TypeError(msg) + actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) np.testing.assert_allclose( # pyright: ignore[reportCallIssue] From 941bd078fa4ca7d5caf056a389c71f800dad2814 Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Sun, 17 May 2026 03:22:42 +0530 Subject: [PATCH 2/6] checking for ndim attribute instead Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index f15481f8..fd2e0e89 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -287,13 +287,13 @@ def xp_assert_close( else: rtol = 1e-7 - if not isinstance(atol, float): + if hasattr(atol, "ndim"): atol = as_numpy_array(atol, xp=xp) if atol.ndim > 0: msg = "atol must be a scalar or 0-D array" raise TypeError(msg) - if not isinstance(rtol, float): + if hasattr(rtol, "ndim"): rtol = as_numpy_array(rtol, xp=xp) if rtol.ndim > 0: msg = "rtol must be a scalar or 0-D array" From 89dd532112100e81c4ec851d20aca6f2c1c69c12 Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Sun, 17 May 2026 03:41:11 +0530 Subject: [PATCH 3/6] changing condition Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index fd2e0e89..1b514ac5 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -288,16 +288,12 @@ def xp_assert_close( rtol = 1e-7 if hasattr(atol, "ndim"): - atol = as_numpy_array(atol, xp=xp) - if atol.ndim > 0: - msg = "atol must be a scalar or 0-D array" - raise TypeError(msg) + if atol.ndim == 0: + atol = as_numpy_array(atol, xp=xp) if hasattr(rtol, "ndim"): - rtol = as_numpy_array(rtol, xp=xp) - if rtol.ndim > 0: - msg = "rtol must be a scalar or 0-D array" - raise TypeError(msg) + if rtol.ndim == 0: + rtol = as_numpy_array(rtol, xp=xp) actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) From 3848cf2523d2bd1e0a18a5d2d78aea21a810b4eb Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sat, 16 May 2026 23:17:11 +0100 Subject: [PATCH 4/6] typing --- src/array_api_extra/_lib/_testing.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 1b514ac5..253f129a 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -287,21 +287,20 @@ def xp_assert_close( else: rtol = 1e-7 - if hasattr(atol, "ndim"): - if atol.ndim == 0: - atol = as_numpy_array(atol, xp=xp) + if hasattr(atol, "ndim") and atol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue] + atol = cast(Array, as_numpy_array(cast(Array, atol), xp=xp)) # pyright: ignore[reportInvalidCast] - if hasattr(rtol, "ndim"): - if rtol.ndim == 0: - rtol = as_numpy_array(rtol, xp=xp) + if hasattr(rtol, "ndim") and rtol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + rtol = cast(Array, as_numpy_array(cast(Array, rtol), xp=xp)) # pyright: ignore[reportInvalidCast] actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) - np.testing.assert_allclose( # pyright: ignore[reportCallIssue] + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] # pyrefly: ignore[no-matching-overload] actual_np, desired_np, - rtol=rtol, # pyright: ignore[reportArgumentType] - atol=atol, + # https://github.com/numpy/numpy/issues/31449 + rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + atol=atol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] equal_nan=equal_nan, err_msg=err_msg, verbose=verbose, From 9e379315d42c562b122549c8ce87ef050aa1046b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sat, 16 May 2026 23:57:32 +0100 Subject: [PATCH 5/6] try just `float` --- src/array_api_extra/_lib/_testing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 253f129a..6cd513ff 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -271,6 +271,8 @@ def xp_assert_close( Notes ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. + + Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`. """ actual, desired, xp = _check_ns_shape_dtype( actual, desired, check_dtype, check_shape, check_scalar @@ -286,21 +288,19 @@ def xp_assert_close( rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 else: rtol = 1e-7 + else: + rtol = float(rtol) - if hasattr(atol, "ndim") and atol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue] - atol = cast(Array, as_numpy_array(cast(Array, atol), xp=xp)) # pyright: ignore[reportInvalidCast] - - if hasattr(rtol, "ndim") and rtol.ndim == 0: # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] - rtol = cast(Array, as_numpy_array(cast(Array, rtol), xp=xp)) # pyright: ignore[reportInvalidCast] + atol = float(atol) actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) - np.testing.assert_allclose( # pyright: ignore[reportCallIssue] # pyrefly: ignore[no-matching-overload] + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] actual_np, desired_np, # https://github.com/numpy/numpy/issues/31449 rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - atol=atol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + atol=atol, equal_nan=equal_nan, err_msg=err_msg, verbose=verbose, From c7e31ddf317a97841c01bd765246e570cf41fd1f Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+prady0t@users.noreply.github.com> Date: Sun, 17 May 2026 11:56:36 +0530 Subject: [PATCH 6/6] Update src/array_api_extra/_lib/_testing.py Co-authored-by: Lucas Colley --- src/array_api_extra/_lib/_testing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 6cd513ff..26afd31e 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -298,7 +298,6 @@ def xp_assert_close( np.testing.assert_allclose( # pyright: ignore[reportCallIssue] actual_np, desired_np, - # https://github.com/numpy/numpy/issues/31449 rtol=rtol, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] atol=atol, equal_nan=equal_nan,