Skip to content

Commit 111a122

Browse files
committed
Fix sort() and argsort() with cupy
1 parent c3eb0d5 commit 111a122

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

array_api_compat/common/_aliases.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,17 +332,19 @@ def argsort(
332332
**kwargs,
333333
) -> ndarray:
334334
# Note: this keyword argument is different, and the default is different.
335-
kind = "stable" if stable else "quicksort"
335+
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
336+
# as the default whereas cupy.sort uses kind=None.
337+
if stable:
338+
kwargs['kind'] = "stable"
336339
if not descending:
337-
res = xp.argsort(x, axis=axis, kind=kind, **kwargs)
340+
res = xp.argsort(x, axis=axis, **kwargs)
338341
else:
339342
# As NumPy has no native descending sort, we imitate it here. Note that
340343
# simply flipping the results of xp.argsort(x, ...) would not
341344
# respect the relative order like it would in native descending sorts.
342345
res = xp.flip(
343-
xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind),
346+
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
344347
axis=axis,
345-
**kwargs,
346348
)
347349
# Rely on flip()/argsort() to validate axis
348350
normalised_axis = axis if axis >= 0 else x.ndim + axis
@@ -355,8 +357,11 @@ def sort(
355357
**kwargs,
356358
) -> ndarray:
357359
# Note: this keyword argument is different, and the default is different.
358-
kind = "stable" if stable else "quicksort"
359-
res = xp.sort(x, axis=axis, kind=kind, **kwargs)
360+
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
361+
# as the default whereas cupy.sort uses kind=None.
362+
if stable:
363+
kwargs['kind'] = "stable"
364+
res = xp.sort(x, axis=axis, **kwargs)
360365
if descending:
361366
res = xp.flip(res, axis=axis)
362367
return res

0 commit comments

Comments
 (0)