@@ -293,19 +293,14 @@ def unique_inverse(x):
293
293
exec_q = array_api_dev .sycl_queue
294
294
x_usm_type = x .usm_type
295
295
ind_dt = default_device_index_type (exec_q )
296
- if x .ndim == 0 :
297
- return UniqueInverseResult (
298
- dpt .reshape (x , (1 ,), order = "C" , copy = True ),
299
- dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
300
- )
301
- elif x .ndim == 1 :
296
+ if x .ndim == 1 :
302
297
fx = x
303
298
else :
304
299
fx = dpt .reshape (x , (x .size ,), order = "C" )
305
300
sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
306
301
unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
307
302
if fx .size == 0 :
308
- return UniqueInverseResult (fx , unsorting_ids )
303
+ return UniqueInverseResult (fx , dpt . reshape ( unsorting_ids , x . shape ) )
309
304
host_tasks = []
310
305
if fx .flags .c_contiguous :
311
306
ht_ev , sort_ev = _argsort_ascending (
@@ -366,7 +361,7 @@ def unique_inverse(x):
366
361
)
367
362
if n_uniques == fx .size :
368
363
dpctl .SyclEvent .wait_for (host_tasks )
369
- return UniqueInverseResult (s , unsorting_ids )
364
+ return UniqueInverseResult (s , dpt . reshape ( unsorting_ids , x . shape ) )
370
365
unique_vals = dpt .empty (
371
366
n_uniques , dtype = x .dtype , usm_type = x_usm_type , sycl_queue = exec_q
372
367
)
@@ -422,7 +417,9 @@ def unique_inverse(x):
422
417
pos = pos_next
423
418
host_tasks .append (ht_ev )
424
419
dpctl .SyclEvent .wait_for (host_tasks )
425
- return UniqueInverseResult (unique_vals , inv [unsorting_ids ])
420
+ return UniqueInverseResult (
421
+ unique_vals , dpt .reshape (inv [unsorting_ids ], x .shape )
422
+ )
426
423
427
424
428
425
def unique_all (x : dpt .usm_ndarray ) -> UniqueAllResult :
@@ -462,17 +459,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
462
459
exec_q = array_api_dev .sycl_queue
463
460
x_usm_type = x .usm_type
464
461
ind_dt = default_device_index_type (exec_q )
465
- if x .ndim == 0 :
466
- uv = dpt .reshape (x , (1 ,), order = "C" , copy = True )
467
- return UniqueAllResult (
468
- uv ,
469
- dpt .zeros_like (uv , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
470
- dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
471
- dpt .ones_like (
472
- uv , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
473
- ),
474
- )
475
- elif x .ndim == 1 :
462
+ if x .ndim == 1 :
476
463
fx = x
477
464
else :
478
465
fx = dpt .reshape (x , (x .size ,), order = "C" )
@@ -482,7 +469,10 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
482
469
# original array contains no data
483
470
# so it can be safely returned as values
484
471
return UniqueAllResult (
485
- fx , sorting_ids , unsorting_ids , dpt .empty_like (fx , dtype = ind_dt )
472
+ fx ,
473
+ sorting_ids ,
474
+ dpt .reshape (unsorting_ids , x .shape ),
475
+ dpt .empty_like (fx , dtype = ind_dt ),
486
476
)
487
477
host_tasks = []
488
478
if fx .flags .c_contiguous :
@@ -550,7 +540,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
550
540
return UniqueAllResult (
551
541
s ,
552
542
sorting_ids ,
553
- unsorting_ids ,
543
+ dpt . reshape ( unsorting_ids , x . shape ) ,
554
544
_counts ,
555
545
)
556
546
unique_vals = dpt .empty (
@@ -611,6 +601,6 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
611
601
return UniqueAllResult (
612
602
unique_vals ,
613
603
sorting_ids [cum_unique_counts [:- 1 ]],
614
- inv [unsorting_ids ],
604
+ dpt . reshape ( inv [unsorting_ids ], x . shape ) ,
615
605
_counts ,
616
606
)
0 commit comments