diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index c5b1a307..60fd8fc8 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -136,6 +136,13 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: cpu = cast(Device, jax.devices("cpu")[0]) array = to_device(array, cpu) + # Try DLPack (works for JAX and other backends) + if hasattr(array, "__dlpack__"): + try: + return np.from_dlpack(array) # pyright: ignore[reportArgumentType] + except (TypeError, BufferError): + pass + return np.asarray(array)