Skip to content

Commit fc69816

Browse files
committed
alllow scalars in searchsorted second argument
1 parent 5e823b1 commit fc69816

2 files changed

Lines changed: 122 additions & 84 deletions

File tree

dpnp/tensor/_searchsorted.py

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,42 @@
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
2828

29-
30-
from typing import Literal, Union
29+
from typing import Literal
3130

3231
import dpctl
3332
import dpctl.utils as du
3433

34+
import dpnp.tensor as dpt
35+
3536
from ._compute_follows_data import (
3637
ExecutionPlacementError,
3738
get_coerced_usm_type,
3839
get_execution_queue,
3940
)
4041
from ._copy_utils import _empty_like_orderK
41-
from ._ctors import empty
42+
from ._ctors import empty_like
43+
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
4244
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
4345
from ._tensor_impl import _take as ti_take
4446
from ._tensor_impl import (
4547
default_device_index_type as ti_default_device_index_type,
4648
)
4749
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
48-
from ._type_utils import isdtype, result_type
50+
from ._type_utils import (
51+
_resolve_weak_types_all_py_ints,
52+
_to_device_supported_dtype,
53+
isdtype,
54+
)
4955
from ._usmarray import usm_ndarray
5056

5157

5258
def searchsorted(
5359
x1: usm_ndarray,
54-
x2: usm_ndarray,
60+
x2: usm_ndarray | int | float | complex | bool,
5561
/,
5662
*,
5763
side: Literal["left", "right"] = "left",
58-
sorter: Union[usm_ndarray, None] = None,
64+
sorter: usm_ndarray | None = None,
5965
) -> usm_ndarray:
6066
"""searchsorted(x1, x2, side='left', sorter=None)
6167
@@ -68,8 +74,8 @@ def searchsorted(
6874
input array. Must be a one-dimensional array. If `sorter` is
6975
`None`, must be sorted in ascending order; otherwise, `sorter` must
7076
be an array of indices that sort `x1` in ascending order.
71-
x2 (usm_ndarray):
72-
array containing search values.
77+
x2 (usm_ndarray | int | float | complex | bool):
78+
search value or values.
7379
side (Literal["left", "right]):
7480
argument controlling which index is returned if a value lands
7581
exactly on an edge. If `x2` is an array of rank `N` where
@@ -85,13 +91,11 @@ def searchsorted(
8591
array of indices that sort `x1` in ascending order. The array must
8692
have the same shape as `x1` and have an integral data type.
8793
Out of bound index values of `sorter` array are treated using
88-
`"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
94+
`"wrap"` mode documented in :py:func:`dpnp.tensor.take`.
8995
Default: `None`.
9096
"""
9197
if not isinstance(x1, usm_ndarray):
9298
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x1)}")
93-
if not isinstance(x2, usm_ndarray):
94-
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x2)}")
9599
if sorter is not None and not isinstance(sorter, usm_ndarray):
96100
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(sorter)}")
97101

@@ -101,23 +105,39 @@ def searchsorted(
101105
"Expected either 'left' or 'right'"
102106
)
103107

104-
if sorter is None:
105-
q = get_execution_queue([x1.sycl_queue, x2.sycl_queue])
106-
else:
107-
q = get_execution_queue(
108-
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
109-
)
108+
q1, x1_usm_type = x1.sycl_queue, x1.usm_type
109+
q2, x2_usm_type = _get_queue_usm_type(x2)
110+
q3 = sorter.sycl_queue if sorter is not None else None
111+
q = get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None))
110112
if q is None:
111113
raise ExecutionPlacementError(
112114
"Execution placement can not be unambiguously "
113115
"inferred from input arguments."
114116
)
115117

118+
res_usm_type = get_coerced_usm_type(
119+
tuple(
120+
ut
121+
for ut in (
122+
x1_usm_type,
123+
x2_usm_type,
124+
)
125+
if ut is not None
126+
)
127+
)
128+
dpt.validate_usm_type(res_usm_type, allow_none=False)
129+
sycl_dev = q.sycl_device
130+
116131
if x1.ndim != 1:
117132
raise ValueError("First argument array must be one-dimensional")
118133

119134
x1_dt = x1.dtype
120-
x2_dt = x2.dtype
135+
x2_dt = _get_dtype(x2, sycl_dev)
136+
if not _validate_dtype(x2_dt):
137+
raise ValueError(
138+
"dpt.searchsorted search value argument has "
139+
f"unsupported data type {x2_dt}"
140+
)
121141

122142
_manager = du.SequentialOrderManager[q]
123143
dep_evs = _manager.submitted_events
@@ -132,7 +152,7 @@ def searchsorted(
132152
"Sorter array must be one-dimension with the same "
133153
"shape as the first argument array"
134154
)
135-
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
155+
res = empty_like(x1)
136156
ind = (sorter,)
137157
axis = 0
138158
wrap_out_of_bound_indices_mode = 0
@@ -148,29 +168,29 @@ def searchsorted(
148168
x1 = res
149169
_manager.add_event_pair(ht_ev, ev)
150170

151-
if x1_dt != x2_dt:
152-
dt = result_type(x1, x2)
153-
if x1_dt != dt:
154-
x1_buf = _empty_like_orderK(x1, dt)
155-
dep_evs = _manager.submitted_events
156-
ht_ev, ev = ti_copy(
157-
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
158-
)
159-
_manager.add_event_pair(ht_ev, ev)
160-
x1 = x1_buf
161-
if x2_dt != dt:
162-
x2_buf = _empty_like_orderK(x2, dt)
163-
dep_evs = _manager.submitted_events
164-
ht_ev, ev = ti_copy(
165-
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
166-
)
167-
_manager.add_event_pair(ht_ev, ev)
168-
x2 = x2_buf
171+
dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev)
172+
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)
173+
174+
if not isinstance(x2, usm_ndarray):
175+
x2 = dpt.asarray(x2, dtype=dt2, usm_type=res_usm_type, sycl_queue=q)
176+
177+
# get submitted events again in case some were added by sorter handling
178+
dep_evs = _manager.submitted_events
179+
if x1_dt != dt:
180+
x1_buf = _empty_like_orderK(x1, dt)
181+
ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs)
182+
_manager.add_event_pair(ht_ev, ev)
183+
x1 = x1_buf
184+
185+
if x2.dtype != dt:
186+
x2_buf = _empty_like_orderK(x2, dt)
187+
ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs)
188+
_manager.add_event_pair(ht_ev, ev)
189+
x2 = x2_buf
169190

170-
dst_usm_type = get_coerced_usm_type([x1.usm_type, x2.usm_type])
171191
index_dt = ti_default_device_index_type(q)
172192

173-
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
193+
dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type)
174194

175195
dep_evs = _manager.submitted_events
176196
if side == "left":

dpnp/tests/tensor/test_usm_ndarray_searchsorted.py

Lines changed: 63 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
2828

29+
import ctypes
30+
2931
import dpctl
3032
import numpy as np
3133
import pytest
@@ -37,6 +39,30 @@
3739
skip_if_dtype_not_supported,
3840
)
3941

42+
_integer_dtypes = [
43+
"i1",
44+
"u1",
45+
"i2",
46+
"u2",
47+
"i4",
48+
"u4",
49+
"i8",
50+
"u8",
51+
]
52+
53+
_floating_dtypes = [
54+
"f2",
55+
"f4",
56+
"f8",
57+
]
58+
59+
_complex_dtypes = [
60+
"c8",
61+
"c16",
62+
]
63+
64+
_all_dtypes = ["?"] + _integer_dtypes + _floating_dtypes + _complex_dtypes
65+
4066

4167
def _check(hay_stack, needles, needles_np):
4268
assert hay_stack.dtype == needles.dtype
@@ -103,19 +129,7 @@ def test_searchsorted_strided_bool():
103129
)
104130

105131

106-
@pytest.mark.parametrize(
107-
"idt",
108-
[
109-
dpt.int8,
110-
dpt.uint8,
111-
dpt.int16,
112-
dpt.uint16,
113-
dpt.int32,
114-
dpt.uint32,
115-
dpt.int64,
116-
dpt.uint64,
117-
],
118-
)
132+
@pytest.mark.parametrize("idt", _integer_dtypes)
119133
def test_searchsorted_contig_int(idt):
120134
q = get_queue_or_skip()
121135
skip_if_dtype_not_supported(idt, q)
@@ -135,19 +149,7 @@ def test_searchsorted_contig_int(idt):
135149
)
136150

137151

138-
@pytest.mark.parametrize(
139-
"idt",
140-
[
141-
dpt.int8,
142-
dpt.uint8,
143-
dpt.int16,
144-
dpt.uint16,
145-
dpt.int32,
146-
dpt.uint32,
147-
dpt.int64,
148-
dpt.uint64,
149-
],
150-
)
152+
@pytest.mark.parametrize("idt", _integer_dtypes)
151153
def test_searchsorted_strided_int(idt):
152154
q = get_queue_or_skip()
153155
skip_if_dtype_not_supported(idt, q)
@@ -174,12 +176,12 @@ def _add_extended_fp(array):
174176
array[-1] = dpt.nan
175177

176178

177-
@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
178-
def test_searchsorted_contig_fp(idt):
179+
@pytest.mark.parametrize("fdt", _floating_dtypes)
180+
def test_searchsorted_contig_fp(fdt):
179181
q = get_queue_or_skip()
180-
skip_if_dtype_not_supported(idt, q)
182+
skip_if_dtype_not_supported(fdt, q)
181183

182-
dt = dpt.dtype(idt)
184+
dt = dpt.dtype(fdt)
183185

184186
hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
185187
_add_extended_fp(hay_stack)
@@ -195,12 +197,12 @@ def test_searchsorted_contig_fp(idt):
195197
)
196198

197199

198-
@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
199-
def test_searchsorted_strided_fp(idt):
200+
@pytest.mark.parametrize("fdt", _floating_dtypes)
201+
def test_searchsorted_strided_fp(fdt):
200202
q = get_queue_or_skip()
201-
skip_if_dtype_not_supported(idt, q)
203+
skip_if_dtype_not_supported(fdt, q)
202204

203-
dt = dpt.dtype(idt)
205+
dt = dpt.dtype(fdt)
204206

205207
hay_stack = dpt.repeat(
206208
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
@@ -243,12 +245,12 @@ def _add_extended_cfp(array):
243245
return dpt.sort(dpt.concat((ev, array)))
244246

245247

246-
@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
247-
def test_searchsorted_contig_cfp(idt):
248+
@pytest.mark.parametrize("cdt", _complex_dtypes)
249+
def test_searchsorted_contig_cfp(cdt):
248250
q = get_queue_or_skip()
249-
skip_if_dtype_not_supported(idt, q)
251+
skip_if_dtype_not_supported(cdt, q)
250252

251-
dt = dpt.dtype(idt)
253+
dt = dpt.dtype(cdt)
252254

253255
hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
254256
hay_stack = _add_extended_cfp(hay_stack)
@@ -263,12 +265,12 @@ def test_searchsorted_contig_cfp(idt):
263265
)
264266

265267

266-
@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
267-
def test_searchsorted_strided_cfp(idt):
268+
@pytest.mark.parametrize("cdt", _complex_dtypes)
269+
def test_searchsorted_strided_cfp(cdt):
268270
q = get_queue_or_skip()
269-
skip_if_dtype_not_supported(idt, q)
271+
skip_if_dtype_not_supported(cdt, q)
270272

271-
dt = dpt.dtype(idt)
273+
dt = dpt.dtype(cdt)
272274

273275
hay_stack = dpt.repeat(
274276
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
@@ -315,7 +317,7 @@ def test_searchsorted_validation():
315317
x1 = dpt.arange(10, dtype="i4")
316318
except dpctl.SyclDeviceCreationError:
317319
pytest.skip("Default device could not be created")
318-
with pytest.raises(TypeError):
320+
with pytest.raises(ValueError):
319321
dpt.searchsorted(x1, None)
320322
with pytest.raises(TypeError):
321323
dpt.searchsorted(x1, x1, sorter=dict())
@@ -333,10 +335,10 @@ def test_searchsorted_validation2():
333335
q2 = dpctl.SyclQueue(d, property="in_order")
334336
x2 = dpt.ones(5, dtype=x1.dtype, sycl_queue=q2)
335337

336-
with pytest.raises(dpt.ExecutionPlacementError):
338+
with pytest.raises(dpu.ExecutionPlacementError):
337339
dpt.searchsorted(x1, x2)
338340

339-
with pytest.raises(dpt.ExecutionPlacementError):
341+
with pytest.raises(dpu.ExecutionPlacementError):
340342
dpt.searchsorted(x1, x2, sorter=sorter)
341343

342344
sorter = dpt.ones(x1.shape, dtype=dpt.bool)
@@ -405,3 +407,19 @@ def test_searchsorted_strided_scalar_needle():
405407
needles = dpt.asarray(needles_np)
406408

407409
_check(hay_stack, needles, needles_np)
410+
411+
412+
@pytest.mark.parametrize(
413+
"py_zero",
414+
[bool(0), int(0), float(0), complex(0), np.float32(0), ctypes.c_int(0)],
415+
)
416+
@pytest.mark.parametrize("dt", _all_dtypes)
417+
def test_searchsorted_py_scalars(py_zero, dt):
418+
q = get_queue_or_skip()
419+
skip_if_dtype_not_supported(dt, q)
420+
421+
x = dpt.zeros(10, dtype=dt, sycl_queue=q)
422+
423+
r1 = dpt.searchsorted(x, py_zero)
424+
assert isinstance(r1, dpt.usm_ndarray)
425+
assert r1.shape == ()

0 commit comments

Comments
 (0)