ENH: Add basic dlpack testing #433

Open: ev-br wants to merge 6 commits into data-apis:master from ev-br:test_dlpack

ev-br commented Mar 23, 2026

Add a basic set of dlpack tests. The tests here are "boring":

  • no cross-library tests, only test from_dlpack for a single array library
  • only test copy=False on the same device
  • no tests for __dlpack__ yet

There are several XXX comments marking places where I have no idea what the right behavior is.

ev-br commented Mar 23, 2026

Issues with these "trivial" tests:

  1. torch has a weird device(type="meta") which is incompatible with DLPack. If anything, the following should raise a BufferError IIUC, not a ValueError:
In [8]: device = torch.device(type="meta")

In [9]: x = torch.empty(3, device=device)

In [10]: torch.from_dlpack(x)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 1
----> 1 torch.from_dlpack(x)
...
ValueError: Unknown device type meta for Dlpack
  2. jax fails with copy=True on CUDA, whereas it works on CPU:
In [8]: jnp.from_dlpack(jnp.ones(3), copy=True)
---------------------------------------------------------------------------
JaxRuntimeError                           Traceback (most recent call last)
Cell In[8], line 1
----> 1 jnp.from_dlpack(jnp.ones(3), copy=True)

...

JaxRuntimeError: UNIMPLEMENTED: PJRT C API does not support HostBufferSemantics other than HostBufferSemantics::kImmutableOnlyDuringCall, HostBufferSemantics::kImmutableZeroCopy and HostBufferSemantics::kImmutableUntilTransferCompletes

  3. cupy==14.0.1 doesn't implement the copy or device arguments:
>>> import cupy
>>> x = cupy.array([], dtype=float)
>>> cupy.from_dlpack(x, copy=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cupy/_core/dlpack.pyx", line 568, in cupy._core.dlpack.from_dlpack
  File "cupy/_core/core.pyx", line 355, in cupy._core.core._ndarray_base.__dlpack__
BufferError: copy=True only supported for copy to CPU.

>>> cupy.from_dlpack(x, device=x.device)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cupy/_core/dlpack.pyx", line 554, in cupy._core.dlpack.from_dlpack
NotImplementedError: from_dlpack() does not support device yet.

>>> cupy.__version__
'14.0.1'
  4. numpy cannot consume a capsule created by __dlpack__, while torch happily does. Which one is correct?
>>> import numpy as np
>>> import torch
>>> x = np.ones(3)
>>> capsule = x.__dlpack__()
>>> torch.from_dlpack(capsule)             # works fine
tensor([1., 1., 1.], dtype=torch.float64)
>>> np.from_dlpack(capsule)                 # ouch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'PyCapsule' object has no attribute '__dlpack__'

ev-br commented Mar 23, 2026

Larger issues:

  1. How to tell if a device is DLPack-compatible. The only way I see now is to call from_dlpack inside a try-except and catch a BufferError.

  2. How to tell if a copy=False device transfer is meant to work, given source_device and target_device. A first reaction is to create a dummy array on each device and check their __dlpack_device__ values. If they match, it should work, ok. But then what about CUDA vs CUDA_MANAGED? CPU vs CPU_PINNED? Other values? It feels like there needs to be something in the inspection capabilities to tell it.

  3. How to tell which max_version is supported in __dlpack__. Should max_version=(-111, 42) raise? What about reasonable version values, e.g. max_version=(3, 0) if a library only supports (2, 0)?
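
The try-except probe from the first item could look roughly like the sketch below. The helper name `supports_dlpack` and its `extra_errors` parameter are hypothetical and not part of any library. The only assumption about the namespace `xp` is that it exposes `from_dlpack`, as the array API standard requires. `extra_errors` exists because, as the torch "meta" example in this thread shows, not every library raises BufferError.

```python
def supports_dlpack(xp, x, extra_errors=()):
    """Return True if ``xp.from_dlpack(x)`` succeeds.

    Hypothetical helper: BufferError is treated as "this array/device
    cannot be exported via DLPack". ``extra_errors`` lets a test suite
    tolerate libraries that signal the same condition with another
    exception type. Any other exception propagates to the caller.
    """
    handled = (BufferError,) + tuple(extra_errors)
    try:
        xp.from_dlpack(x)
    except handled:
        return False
    return True
```

A test could then skip, rather than fail, on devices for which `supports_dlpack(xp, xp.empty(3, device=device))` returns False.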

rgommers commented

(2) JAX doesn't implement DLPack 1.0 yet

(4) That shouldn't be tested; passing in a PyCapsule is not regular usage.

(5) Probably a try-except indeed, given that not every device of every library will be compatible. Seems fine to do it like that. I'd expect __dlpack_device__ to also raise then, but I'm not sure all libraries actually do that.

(6) I think only matching devices are guaranteed to work; anything else probably won't, but there are likely specific combos that will. Which seems fine (there's nothing in the standard that says it's not fine)?
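
A conservative "matching devices" check could be sketched as below. The helper `same_dlpack_device` is hypothetical. It relies only on `__dlpack_device__()` returning a `(device_type, device_id)` pair. The enum lists a subset of the device-type codes from the array API standard. Under this check CUDA vs CUDA_MANAGED and CPU vs CPU_PINNED count as non-matching, even though some of those combinations may work in practice.

```python
import enum


class DLDeviceType(enum.IntEnum):
    # Subset of the DLPack device type codes, listed for readability only.
    CPU = 1
    CUDA = 2
    CPU_PINNED = 3
    CUDA_MANAGED = 13


def same_dlpack_device(a, b):
    """Return True if both objects report the same (device_type, device_id).

    Hypothetical helper: only an exact match is treated as "expected to
    work with copy=False"; every other pairing is left undecided.
    """
    type_a, id_a = a.__dlpack_device__()
    type_b, id_b = b.__dlpack_device__()
    return (int(type_a), int(id_a)) == (int(type_b), int(id_b))
```

A test suite could assert success only when this returns True and accept either outcome otherwise.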

(7) Whether an unreasonable value raises, and what exception type it raises, is almost always undefined in the standard. No need to test, I'd say. For reasonable values, https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html seems unambiguous. If the consumer says (3, 0) and the producer can only do (2, 0), then the producer should export (2, 0), and the consumer then may or may not be able to handle that. If it can't, it'll raise a BufferError.
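
The negotiation described in (7) can be modelled in pure Python as a rough sketch. Both functions and their parameters are hypothetical, and no real library is claimed to implement them this way. The producer is assumed to export the highest version it supports that does not exceed max_version, falling back to its lowest supported version. The consumer is assumed to reject capsules whose major version it cannot read.

```python
def choose_export_version(supported_versions, max_version):
    """Producer side: pick the version to export for a given max_version.

    Assumption: versions are (major, minor) tuples compared
    lexicographically, and the producer never exceeds max_version if it
    has any supported version at or below it.
    """
    limit = tuple(max_version)
    candidates = [v for v in supported_versions if tuple(v) <= limit]
    if candidates:
        return max(candidates)
    # Nothing fits: export the lowest supported version and let the
    # consumer decide whether it can handle it.
    return min(supported_versions)


def check_consumable(exported, readable_majors):
    """Consumer side: raise BufferError if the capsule version is unreadable."""
    if exported[0] not in readable_majors:
        raise BufferError(f"cannot consume DLPack version {exported}")
    return exported
```

With `supported_versions=[(2, 0)]` and `max_version=(3, 0)` the producer exports `(2, 0)`. A consumer that can only read major version 3 then raises BufferError, matching the behavior described above.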

pearu commented Mar 26, 2026

> (2) JAX doesn't implement DLPack 1.0 yet

While this is true (JAX officially supports DLPack 0.8, although it could also support 1.0, since the latest XLA ships with DLPack 1.1), the underlying reason for the failure in (2) is likely different. For instance, (2) worked fine in the past:

>>> import jax
>>> import jax.numpy as jnp
>>> x=jnp.from_dlpack(jnp.ones(3), copy=True)
>>> x.device
CudaDevice(id=0)
>>> jax.__version__
'0.8.1.dev20250703+2ac4523a7'
