Skip to content

Commit 60c6ad6

Browse files
committed
unique_all and unique_inverse inverse indices shape fixed
Were previously returning a 1D array of indices rather than an array with the same shape as input `x`
1 parent a231d56 commit 60c6ad6

File tree

2 files changed

+15
-23
lines changed

2 files changed

+15
-23
lines changed

dpctl/tensor/_set_functions_async.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -293,19 +293,14 @@ def unique_inverse(x):
293293
exec_q = array_api_dev.sycl_queue
294294
x_usm_type = x.usm_type
295295
ind_dt = default_device_index_type(exec_q)
296-
if x.ndim == 0:
297-
return UniqueInverseResult(
298-
dpt.reshape(x, (1,), order="C", copy=True),
299-
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
300-
)
301-
elif x.ndim == 1:
296+
if x.ndim == 1:
302297
fx = x
303298
else:
304299
fx = dpt.reshape(x, (x.size,), order="C")
305300
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
306301
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
307302
if fx.size == 0:
308-
return UniqueInverseResult(fx, unsorting_ids)
303+
return UniqueInverseResult(fx, dpt.reshape(unsorting_ids, x.shape))
309304
host_tasks = []
310305
if fx.flags.c_contiguous:
311306
ht_ev, sort_ev = _argsort_ascending(
@@ -366,7 +361,7 @@ def unique_inverse(x):
366361
)
367362
if n_uniques == fx.size:
368363
dpctl.SyclEvent.wait_for(host_tasks)
369-
return UniqueInverseResult(s, unsorting_ids)
364+
return UniqueInverseResult(s, dpt.reshape(unsorting_ids, x.shape))
370365
unique_vals = dpt.empty(
371366
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
372367
)
@@ -422,7 +417,9 @@ def unique_inverse(x):
422417
pos = pos_next
423418
host_tasks.append(ht_ev)
424419
dpctl.SyclEvent.wait_for(host_tasks)
425-
return UniqueInverseResult(unique_vals, inv[unsorting_ids])
420+
return UniqueInverseResult(
421+
unique_vals, dpt.reshape(inv[unsorting_ids], x.shape)
422+
)
426423

427424

428425
def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
@@ -462,17 +459,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
462459
exec_q = array_api_dev.sycl_queue
463460
x_usm_type = x.usm_type
464461
ind_dt = default_device_index_type(exec_q)
465-
if x.ndim == 0:
466-
uv = dpt.reshape(x, (1,), order="C", copy=True)
467-
return UniqueAllResult(
468-
uv,
469-
dpt.zeros_like(uv, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
470-
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
471-
dpt.ones_like(
472-
uv, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
473-
),
474-
)
475-
elif x.ndim == 1:
462+
if x.ndim == 1:
476463
fx = x
477464
else:
478465
fx = dpt.reshape(x, (x.size,), order="C")
@@ -482,7 +469,10 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
482469
# original array contains no data
483470
# so it can be safely returned as values
484471
return UniqueAllResult(
485-
fx, sorting_ids, unsorting_ids, dpt.empty_like(fx, dtype=ind_dt)
472+
fx,
473+
sorting_ids,
474+
dpt.reshape(unsorting_ids, x.shape),
475+
dpt.empty_like(fx, dtype=ind_dt),
486476
)
487477
host_tasks = []
488478
if fx.flags.c_contiguous:
@@ -550,7 +540,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
550540
return UniqueAllResult(
551541
s,
552542
sorting_ids,
553-
unsorting_ids,
543+
dpt.reshape(unsorting_ids, x.shape),
554544
_counts,
555545
)
556546
unique_vals = dpt.empty(
@@ -611,6 +601,6 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
611601
return UniqueAllResult(
612602
unique_vals,
613603
sorting_ids[cum_unique_counts[:-1]],
614-
inv[unsorting_ids],
604+
dpt.reshape(inv[unsorting_ids], x.shape),
615605
_counts,
616606
)

dpctl/tests/test_usm_ndarray_unique.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def test_unique_inverse(dtype):
117117
uv, inv = dpt.unique_inverse(inp)
118118
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
119119
assert dpt.all(inp == uv[inv])
120+
assert inp.shape == inv.shape
120121

121122

122123
@pytest.mark.parametrize(
@@ -151,6 +152,7 @@ def test_unique_all(dtype):
151152
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
152153
assert dpt.all(uv == inp[ind])
153154
assert dpt.all(inp == uv[inv])
155+
assert inp.shape == inv.shape
154156
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))
155157

156158

0 commit comments

Comments
 (0)