Skip to content

Commit a231d56

Browse files
committed
Sorting, unique functions properly handle 0D arrays
1 parent 4ecf30a commit a231d56

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

dpctl/tensor/_set_functions_async.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,16 @@ def unique_inverse(x):
292292
array_api_dev = x.device
293293
exec_q = array_api_dev.sycl_queue
294294
x_usm_type = x.usm_type
295-
if x.ndim == 1:
295+
ind_dt = default_device_index_type(exec_q)
296+
if x.ndim == 0:
297+
return UniqueInverseResult(
298+
dpt.reshape(x, (1,), order="C", copy=True),
299+
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
300+
)
301+
elif x.ndim == 1:
296302
fx = x
297303
else:
298304
fx = dpt.reshape(x, (x.size,), order="C")
299-
ind_dt = default_device_index_type(exec_q)
300305
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
301306
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
302307
if fx.size == 0:
@@ -456,11 +461,21 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
456461
array_api_dev = x.device
457462
exec_q = array_api_dev.sycl_queue
458463
x_usm_type = x.usm_type
459-
if x.ndim == 1:
464+
ind_dt = default_device_index_type(exec_q)
465+
if x.ndim == 0:
466+
uv = dpt.reshape(x, (1,), order="C", copy=True)
467+
return UniqueAllResult(
468+
uv,
469+
dpt.zeros_like(uv, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
470+
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
471+
dpt.ones_like(
472+
uv, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
473+
),
474+
)
475+
elif x.ndim == 1:
460476
fx = x
461477
else:
462478
fx = dpt.reshape(x, (x.size,), order="C")
463-
ind_dt = default_device_index_type(exec_q)
464479
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
465480
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
466481
if fx.size == 0:

dpctl/tensor/_sorting.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ def sort(x, /, *, axis=-1, descending=False, stable=False):
6161
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
6262
)
6363
nd = x.ndim
64-
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
64+
if nd == 0:
65+
axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
66+
return dpt.copy(x, order="C")
67+
else:
68+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
6569
a1 = axis + 1
6670
if a1 == nd:
6771
perm = list(range(nd))
@@ -134,7 +138,13 @@ def argsort(x, axis=-1, descending=False, stable=False):
134138
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
135139
)
136140
nd = x.ndim
137-
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
141+
if nd == 0:
142+
axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
143+
return dpt.zeros_like(
144+
x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
145+
)
146+
else:
147+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
138148
a1 = axis + 1
139149
if a1 == nd:
140150
perm = list(range(nd))

0 commit comments

Comments
 (0)