@@ -332,17 +332,19 @@ def argsort(
332
332
** kwargs ,
333
333
) -> ndarray :
334
334
# Note: this keyword argument is different, and the default is different.
335
- kind = "stable" if stable else "quicksort"
335
+ # We set it in kwargs like this because numpy.sort uses kind='quicksort'
336
+ # as the default whereas cupy.sort uses kind=None.
337
+ if stable :
338
+ kwargs ['kind' ] = "stable"
336
339
if not descending :
337
- res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
340
+ res = xp .argsort (x , axis = axis , ** kwargs )
338
341
else :
339
342
# As NumPy has no native descending sort, we imitate it here. Note that
340
343
# simply flipping the results of xp.argsort(x, ...) would not
341
344
# respect the relative order like it would in native descending sorts.
342
345
res = xp .flip (
343
- xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
346
+ xp .argsort (xp .flip (x , axis = axis ), axis = axis , ** kwargs ),
344
347
axis = axis ,
345
- ** kwargs ,
346
348
)
347
349
# Rely on flip()/argsort() to validate axis
348
350
normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -355,8 +357,11 @@ def sort(
355
357
** kwargs ,
356
358
) -> ndarray :
357
359
# Note: this keyword argument is different, and the default is different.
358
- kind = "stable" if stable else "quicksort"
359
- res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
360
+ # We set it in kwargs like this because numpy.sort uses kind='quicksort'
361
+ # as the default whereas cupy.sort uses kind=None.
362
+ if stable :
363
+ kwargs ['kind' ] = "stable"
364
+ res = xp .sort (x , axis = axis , ** kwargs )
360
365
if descending :
361
366
res = xp .flip (res , axis = axis )
362
367
return res
0 commit comments