Skip to content

Commit a429352

Browse files
committed
MAINT: split argsort
1 parent 8b83bfb commit a429352

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

torch_np/_detail/implementations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ def tensor_angle(z, deg=False):
8080
return result
8181

8282

83+
# ### sorting ###
84+
85+
def tensor_argsort(tensor, axis=-1, kind=None, order=None):
86+
if order is not None:
87+
raise NotImplementedError
88+
stable = True if kind == "stable" else False
89+
if axis is None:
90+
axis = -1
91+
return torch.argsort(tensor, stable=stable, dim=axis, descending=False)
92+
93+
8394
# ### splits ###
8495

8596

torch_np/_wrapper.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,6 @@ def triu(m, k=0):
705705
return m.triu(k)
706706

707707

708-
# YYY: pattern: return sequence
709708
def tril_indices(n, k=0, m=None):
710709
if m is None:
711710
m = n
@@ -1012,12 +1011,8 @@ def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue):
10121011

10131012
@asarray_replacer()
10141013
def argsort(a, axis=-1, kind=None, order=None):
1015-
if order is not None:
1016-
raise NotImplementedError
1017-
stable = True if kind == "stable" else False
1018-
if axis is None:
1019-
axis = -1
1020-
return torch.argsort(a, stable=stable, dim=axis, descending=False)
1014+
result = _impl.tensor_argsort(a, axis, kind, order)
1015+
return result
10211016

10221017

10231018
##### math functions

0 commit comments

Comments
 (0)