Conversation
|
Issues with these "trivial" tests:
|
|
Larger issues:
|
|
(2) JAX doesn't implement DLPack 1.0 yet (4) that shouldn't be tested, passing in a PyCapsule is not regular usage. (5) probably a try-except indeed, given that not every device of every library will be compatible. seems fine to do it like that. I'd expect (6) I think only matching devices are guaranteed to work, anything else probably won't but there are likely specific combos that will. Which seems fine (there's nothing in the standard that says it's not fine)? (7) Whether an unreasonable value raises and what exception type it raises is almost always undefined in the standard. No need to test I'd say. For reasonable values: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html seems unambiguous. If the consumer say |
While this is true (JAX officially supports 0.8, although it could support also 1.0 as the latest XLA ships with dlpack 1.1), the underlying reason for failure in (2) is likely different. For instance, (2) worked fine in past: >>> import jax.numpy as jnp
>>> x=jnp.from_dlpack(jnp.ones(3), copy=True)
>>> x.device
CudaDevice(id=0)
>>> jax.__version__
'0.8.1.dev20250703+2ac4523a7' |
Add a basic set of
dlpacktests. The tests here are "boring":from_dlpackfor a single array librarycopy=Falseon the same deviceno tests forseveral__dlpack__yetXXXcomments to mark where I've no idea.