Skip to content

Commit 8b75e3f

Browse files
committed
MAINT: use copyto in ndarray.sort, nuke _ufuncs_impl._sort
1 parent 45d95e1 commit 8b75e3f

File tree

3 files changed

+5
-10
lines changed

3 files changed

+5
-10
lines changed

torch_np/_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from . import _funcs_impl
44
from ._normalizations import normalizer
55

6-
# _funcs_imple.py contains functions which mimic NumPy's eponimous equivalents,
6+
# _funcs_impl.py contains functions which mimic NumPy's eponimous equivalents,
77
# and consume/return PyTorch tensors/dtypes.
88
# They are also type annotated.
99
# Pull these functions from _funcs_impl and decorate them with @normalizer, which

torch_np/_funcs_impl.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,15 +1237,10 @@ def _sort_helper(tensor, axis, kind, order):
12371237
return tensor, axis, stable
12381238

12391239

1240-
def _sort(tensor, axis, kind, order):
1241-
# pure torch implementation, used below and in ndarray.sort
1242-
tensor, axis, stable = _sort_helper(tensor, axis, kind, order)
1243-
result = torch.sort(tensor, dim=axis, stable=stable)
1244-
return result.values
1245-
1246-
12471240
def sort(a: ArrayLike, axis=-1, kind=None, order=None):
1248-
return _sort(a, axis, kind, order)
1241+
a, axis, stable = _sort_helper(a, axis, kind, order)
1242+
result = torch.sort(a, dim=axis, stable=stable)
1243+
return result.values
12491244

12501245

12511246
def argsort(a: ArrayLike, axis=-1, kind=None, order=None):

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def resize(self, *new_shape, refcheck=False):
376376

377377
def sort(self, axis=-1, kind=None, order=None):
378378
# ndarray.sort works in-place
379-
self.tensor.copy_(_funcs_impl._sort(self.tensor, axis, kind, order))
379+
_funcs.copyto(self, _funcs.sort(self, axis, kind, order))
380380

381381
argsort = _funcs.argsort
382382
searchsorted = _funcs.searchsorted

0 commit comments

Comments
 (0)