2626# THE POSSIBILITY OF SUCH DAMAGE.
2727# *****************************************************************************
2828
29-
30- from typing import Literal , Union
29+ from typing import Literal
3130
3231import dpctl
3332import dpctl .utils as du
3433
34+ import dpnp .tensor as dpt
35+
3536from ._compute_follows_data import (
3637 ExecutionPlacementError ,
3738 get_coerced_usm_type ,
3839 get_execution_queue ,
3940)
4041from ._copy_utils import _empty_like_orderK
41- from ._ctors import empty
42+ from ._ctors import empty_like
43+ from ._scalar_utils import _get_dtype , _get_queue_usm_type , _validate_dtype
4244from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
4345from ._tensor_impl import _take as ti_take
4446from ._tensor_impl import (
4547 default_device_index_type as ti_default_device_index_type ,
4648)
4749from ._tensor_sorting_impl import _searchsorted_left , _searchsorted_right
48- from ._type_utils import isdtype , result_type
50+ from ._type_utils import (
51+ _resolve_weak_types_all_py_ints ,
52+ _to_device_supported_dtype ,
53+ isdtype ,
54+ )
4955from ._usmarray import usm_ndarray
5056
5157
5258def searchsorted (
5359 x1 : usm_ndarray ,
54- x2 : usm_ndarray ,
60+ x2 : usm_ndarray | int | float | complex | bool ,
5561 / ,
5662 * ,
5763 side : Literal ["left" , "right" ] = "left" ,
58- sorter : Union [ usm_ndarray , None ] = None ,
64+ sorter : usm_ndarray | None = None ,
5965) -> usm_ndarray :
6066 """searchsorted(x1, x2, side='left', sorter=None)
6167
@@ -68,8 +74,8 @@ def searchsorted(
6874 input array. Must be a one-dimensional array. If `sorter` is
6975 `None`, must be sorted in ascending order; otherwise, `sorter` must
7076 be an array of indices that sort `x1` in ascending order.
71- x2 (usm_ndarray):
72- array containing search values.
77+ x2 (usm_ndarray | int | float | complex | bool ):
78+ search value or values.
7379 side (Literal["left", "right]):
7480 argument controlling which index is returned if a value lands
7581 exactly on an edge. If `x2` is an array of rank `N` where
@@ -85,13 +91,11 @@ def searchsorted(
8591 array of indices that sort `x1` in ascending order. The array must
8692 have the same shape as `x1` and have an integral data type.
8793 Out of bound index values of `sorter` array are treated using
88- `"wrap"` mode documented in :py:func:`dpctl .tensor.take`.
94+ `"wrap"` mode documented in :py:func:`dpnp .tensor.take`.
8995 Default: `None`.
9096 """
9197 if not isinstance (x1 , usm_ndarray ):
9298 raise TypeError (f"Expected dpnp.tensor.usm_ndarray, got { type (x1 )} " )
93- if not isinstance (x2 , usm_ndarray ):
94- raise TypeError (f"Expected dpnp.tensor.usm_ndarray, got { type (x2 )} " )
9599 if sorter is not None and not isinstance (sorter , usm_ndarray ):
96100 raise TypeError (f"Expected dpnp.tensor.usm_ndarray, got { type (sorter )} " )
97101
@@ -101,23 +105,39 @@ def searchsorted(
101105 "Expected either 'left' or 'right'"
102106 )
103107
104- if sorter is None :
105- q = get_execution_queue ([x1 .sycl_queue , x2 .sycl_queue ])
106- else :
107- q = get_execution_queue (
108- [x1 .sycl_queue , x2 .sycl_queue , sorter .sycl_queue ]
109- )
108+ q1 , x1_usm_type = x1 .sycl_queue , x1 .usm_type
109+ q2 , x2_usm_type = _get_queue_usm_type (x2 )
110+ q3 = sorter .sycl_queue if sorter is not None else None
111+ q = get_execution_queue (tuple (q for q in (q1 , q2 , q3 ) if q is not None ))
110112 if q is None :
111113 raise ExecutionPlacementError (
112114 "Execution placement can not be unambiguously "
113115 "inferred from input arguments."
114116 )
115117
118+ res_usm_type = get_coerced_usm_type (
119+ tuple (
120+ ut
121+ for ut in (
122+ x1_usm_type ,
123+ x2_usm_type ,
124+ )
125+ if ut is not None
126+ )
127+ )
128+ dpt .validate_usm_type (res_usm_type , allow_none = False )
129+ sycl_dev = q .sycl_device
130+
116131 if x1 .ndim != 1 :
117132 raise ValueError ("First argument array must be one-dimensional" )
118133
119134 x1_dt = x1 .dtype
120- x2_dt = x2 .dtype
135+ x2_dt = _get_dtype (x2 , sycl_dev )
136+ if not _validate_dtype (x2_dt ):
137+ raise ValueError (
138+ "dpt.searchsorted search value argument has "
139+ f"unsupported data type { x2_dt } "
140+ )
121141
122142 _manager = du .SequentialOrderManager [q ]
123143 dep_evs = _manager .submitted_events
@@ -132,7 +152,7 @@ def searchsorted(
132152 "Sorter array must be one-dimension with the same "
133153 "shape as the first argument array"
134154 )
135- res = empty (x1 . shape , dtype = x1_dt , usm_type = x1 . usm_type , sycl_queue = q )
155+ res = empty_like (x1 )
136156 ind = (sorter ,)
137157 axis = 0
138158 wrap_out_of_bound_indices_mode = 0
@@ -148,29 +168,29 @@ def searchsorted(
148168 x1 = res
149169 _manager .add_event_pair (ht_ev , ev )
150170
151- if x1_dt != x2_dt :
152- dt = result_type (x1 , x2 )
153- if x1_dt != dt :
154- x1_buf = _empty_like_orderK (x1 , dt )
155- dep_evs = _manager .submitted_events
156- ht_ev , ev = ti_copy (
157- src = x1 , dst = x1_buf , sycl_queue = q , depends = dep_evs
158- )
159- _manager .add_event_pair (ht_ev , ev )
160- x1 = x1_buf
161- if x2_dt != dt :
162- x2_buf = _empty_like_orderK (x2 , dt )
163- dep_evs = _manager .submitted_events
164- ht_ev , ev = ti_copy (
165- src = x2 , dst = x2_buf , sycl_queue = q , depends = dep_evs
166- )
167- _manager .add_event_pair (ht_ev , ev )
168- x2 = x2_buf
171+ dt1 , dt2 = _resolve_weak_types_all_py_ints (x1_dt , x2_dt , sycl_dev )
172+ dt = _to_device_supported_dtype (dpt .result_type (dt1 , dt2 ), sycl_dev )
173+
174+ if not isinstance (x2 , usm_ndarray ):
175+ x2 = dpt .asarray (x2 , dtype = dt2 , usm_type = res_usm_type , sycl_queue = q )
176+
177+ # get submitted events again in case some were added by sorter handling
178+ dep_evs = _manager .submitted_events
179+ if x1_dt != dt :
180+ x1_buf = _empty_like_orderK (x1 , dt )
181+ ht_ev , ev = ti_copy (src = x1 , dst = x1_buf , sycl_queue = q , depends = dep_evs )
182+ _manager .add_event_pair (ht_ev , ev )
183+ x1 = x1_buf
184+
185+ if x2 .dtype != dt :
186+ x2_buf = _empty_like_orderK (x2 , dt )
187+ ht_ev , ev = ti_copy (src = x2 , dst = x2_buf , sycl_queue = q , depends = dep_evs )
188+ _manager .add_event_pair (ht_ev , ev )
189+ x2 = x2_buf
169190
170- dst_usm_type = get_coerced_usm_type ([x1 .usm_type , x2 .usm_type ])
171191 index_dt = ti_default_device_index_type (q )
172192
173- dst = _empty_like_orderK (x2 , index_dt , usm_type = dst_usm_type )
193+ dst = _empty_like_orderK (x2 , index_dt , usm_type = res_usm_type )
174194
175195 dep_evs = _manager .submitted_events
176196 if side == "left" :
0 commit comments