@@ -292,11 +292,16 @@ def unique_inverse(x):
292
292
array_api_dev = x .device
293
293
exec_q = array_api_dev .sycl_queue
294
294
x_usm_type = x .usm_type
295
- if x .ndim == 1 :
295
+ ind_dt = default_device_index_type (exec_q )
296
+ if x .ndim == 0 :
297
+ return UniqueInverseResult (
298
+ dpt .reshape (x , (1 ,), order = "C" , copy = True ),
299
+ dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
300
+ )
301
+ elif x .ndim == 1 :
296
302
fx = x
297
303
else :
298
304
fx = dpt .reshape (x , (x .size ,), order = "C" )
299
- ind_dt = default_device_index_type (exec_q )
300
305
sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
301
306
unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
302
307
if fx .size == 0 :
@@ -456,11 +461,21 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
456
461
array_api_dev = x .device
457
462
exec_q = array_api_dev .sycl_queue
458
463
x_usm_type = x .usm_type
459
- if x .ndim == 1 :
464
+ ind_dt = default_device_index_type (exec_q )
465
+ if x .ndim == 0 :
466
+ uv = dpt .reshape (x , (1 ,), order = "C" , copy = True )
467
+ return UniqueAllResult (
468
+ uv ,
469
+ dpt .zeros_like (uv , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
470
+ dpt .zeros_like (x , ind_dt , usm_type = x_usm_type , sycl_queue = exec_q ),
471
+ dpt .ones_like (
472
+ uv , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
473
+ ),
474
+ )
475
+ elif x .ndim == 1 :
460
476
fx = x
461
477
else :
462
478
fx = dpt .reshape (x , (x .size ,), order = "C" )
463
- ind_dt = default_device_index_type (exec_q )
464
479
sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
465
480
unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
466
481
if fx .size == 0 :
0 commit comments