Skip to content

Commit f68dc36

Browse files
committed
fix: ->jvm on object-dtype numpy arrays with np-array bindings
The np-array zero-copy bindings register a pyobject->jvm :ndarray method that always maps the numpy dtype to a native tensor datatype. Object and other non-numeric dtypes have no such mapping, so ->jvm threw "Unable to find datatype: object" once the bindings were loaded. Route non-numeric dtypes to the default element-wise copy conversion instead. Closes #187
1 parent 1248845 commit f68dc36

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

src/libpython_clj2/python/np_array.clj

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,25 @@
7373
:ctypes ctypes})))
7474

7575

76+
(defn- zero-copyable-dtype?
77+
"True when the name of pyobj's numpy dtype is a key of py-dtype->dtype-map. Any
78+
dtype with no entry there (e.g. object, which stores python-object pointers) yields false."
79+
[pyobj]
80+
(py-ffi/with-gil
81+
(let [dtype-name (-> (py-proto/get-attr pyobj "dtype") ; reads pyobj.dtype.name
82+
(py-proto/get-attr "name"))]
83+
(contains? py-dtype->dtype-map dtype-name))))
84+
85+
7686
(defmethod py-proto/pyobject->jvm :ndarray
7787
[pyobj opts]
78-
(pygc/with-stack-context
79-
(-> (numpy->desc pyobj)
80-
(dtt/nd-buffer-descriptor->tensor)
81-
(dtt/clone))))
88+
(if (zero-copyable-dtype? pyobj)
89+
(pygc/with-stack-context
90+
(-> (numpy->desc pyobj)
91+
(dtt/nd-buffer-descriptor->tensor)
92+
(dtt/clone)))
93+
;; dtype name is not in py-dtype->dtype-map (e.g. object), so there is no tensor
;; datatype for it: delegate to the :default pyobject->jvm method instead
94+
((get-method py-proto/pyobject->jvm :default) pyobj opts)))
8295

8396

8497
(defmethod py-proto/pyobject-as-jvm :ndarray

test/libpython_clj2/numpy_test.clj

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
(ns libpython-clj2.numpy-test
22
(:require [clojure.test :refer [deftest is]]
33
[libpython-clj2.python :as py]
4+
;; loading these zero-copy bindings is what made object-dtype ->jvm fail
5+
[libpython-clj2.python.np-array]
46
[tech.v3.datatype :as dtype]
57
[tech.v3.datatype.functional :as dfn]
68
[tech.v3.tensor :as dtt]))
@@ -18,3 +20,10 @@
1820
(is (dfn/equals (dtt/ensure-tensor np-ary) tens))
1921
(is (dfn/equals [1 2 3 4]
2022
(dtype/make-container :java-array :int64 np-ary)))))
23+
24+
25+
(deftest object-dtype-ndarray->jvm ; ->jvm on numpy arrays built with dtype "object"
26+
(let [empty-obj (py/call-attr-kw np-mod "array" [[]] {:dtype "object"})
27+
mixed-obj (py/call-attr-kw np-mod "array" [["a" 1]] {:dtype "object"})]
28+
(is (= [] (py/->jvm empty-obj))) ; empty object array compares equal to []
29+
(is (= ["a" 1] (py/->jvm mixed-obj))))) ; mixed string/integer elements come back as JVM values

0 commit comments

Comments
 (0)