Skip to content

Commit c730cb7

Browse files
committed
MAINT: move ndarray methods to free functions
1 parent d70cab6 commit c730cb7

File tree

5 files changed

+155
-148
lines changed

5 files changed

+155
-148
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._detail._index_tricks import *
55
from ._detail._util import AxisError, UFuncTypeError
66
from ._dtypes import *
7+
from ._funcs import *
78
from ._getlimits import errstate, finfo, iinfo
89
from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
910
from ._unary_ufuncs import *

torch_np/_funcs.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import torch
2+
3+
from . import _decorators, _helpers
4+
from ._detail import _util
5+
from ._detail import implementations as _impl
6+
7+
8+
def nonzero(a):
9+
(tensor,) = _helpers.to_tensors(a)
10+
result = tensor.nonzero(as_tuple=True)
11+
return _helpers.tuple_arrays_from(result)
12+
13+
14+
def argwhere(a):
15+
(tensor,) = _helpers.to_tensors(a)
16+
result = torch.argwhere(tensor)
17+
return _helpers.array_from(result)
18+
19+
20+
def clip(a, min=None, max=None, out=None):
21+
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
22+
# one of them to be None. Follow the more lax version.
23+
# Also min/max as arg names: follow numpy naming.
24+
tensor, t_min, t_max = _helpers.to_tensors_or_none(a, min, max)
25+
result = _impl.clip(tensor, t_min, t_max)
26+
return _helpers.result_or_out(result, out)
27+
28+
29+
def repeat(a, repeats, axis=None):
30+
tensor, t_repeats = _helpers.to_tensors(a, repeats) # XXX: scalar repeats
31+
result = torch.repeat_interleave(tensor, t_repeats, axis)
32+
return _helpers.array_from(result)
33+
34+
35+
# ### diag et al ###
36+
37+
38+
def diagonal(a, offset=0, axis1=0, axis2=1):
39+
(tensor,) = _helpers.to_tensors(a)
40+
result = _impl.diagonal(tensor, offset, axis1, axis2)
41+
return _helpers.array_from(result)
42+
43+
44+
@_decorators.dtype_to_torch
45+
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
46+
(tensor,) = _helpers.to_tensors(a)
47+
result = _impl.trace(tensor, offset, axis1, axis2, dtype)
48+
return _helpers.result_or_out(result, out)
49+
50+
51+
@_decorators.dtype_to_torch
52+
def eye(N, M=None, k=0, dtype=float, order="C", *, like=None):
53+
_util.subok_not_ok(like)
54+
if order != "C":
55+
raise NotImplementedError
56+
result = _impl.eye(N, M, k, dtype)
57+
return _helpers.array_from(result)
58+
59+
60+
@_decorators.dtype_to_torch
61+
def identity(n, dtype=None, *, like=None):
62+
_util.subok_not_ok(like)
63+
result = torch.eye(n, dtype=dtype)
64+
return _helpers.array_from(result)
65+
66+
67+
def diag(v, k=0):
68+
(tensor,) = _helpers.to_tensors(v)
69+
result = torch.diag(tensor, k)
70+
return _helpers.array_from(result)
71+
72+
73+
def diagflat(v, k=0):
74+
(tensor,) = _helpers.to_tensors(v)
75+
result = torch.diagflat(tensor, k)
76+
return _helpers.array_from(result)
77+
78+
79+
def diag_indices(n, ndim=2):
80+
result = _impl.diag_indices(n, ndim)
81+
return _helpers.tuple_arrays_from(result)
82+
83+
84+
def diag_indices_from(arr):
85+
(tensor,) = _helpers.to_tensors(arr)
86+
result = _impl.diag_indices_from(tensor)
87+
return _helpers.tuple_arrays_from(result)
88+
89+
90+
def fill_diagonal(a, val, wrap=False):
91+
tensor, t_val = _helpers.to_tensors(a, val)
92+
result = _impl.fill_diagonal(tensor, t_val, wrap)
93+
return _helpers.array_from(result)
94+
95+
96+
# ### sorting ###
97+
98+
# ### sort and partition ###
99+
100+
101+
def sort(a, axis=-1, kind=None, order=None):
102+
(tensor,) = _helpers.to_tensors(a)
103+
result = _impl.sort(tensor, axis, kind, order)
104+
return _helpers.array_from(result)
105+
106+
107+
def argsort(a, axis=-1, kind=None, order=None):
108+
(tensor,) = _helpers.to_tensors(a)
109+
result = _impl.argsort(tensor, axis, kind, order)
110+
return _helpers.array_from(result)
111+
112+
113+
def searchsorted(a, v, side="left", sorter=None):
114+
a_t, v_t, sorter_t = _helpers.to_tensors_or_none(a, v, sorter)
115+
result = torch.searchsorted(a_t, v_t, side=side, sorter=sorter_t)
116+
return _helpers.array_from(result)

torch_np/_helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def cast_and_broadcast(tensors, out, casting):
4141
return tuple(tensors)
4242

4343

44+
# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###
45+
46+
4447
def result_or_out(result_tensor, out_array=None, promote_scalar=False):
4548
"""A helper for returns with out= argument.
4649
@@ -70,6 +73,21 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
7073
return asarray(result_tensor)
7174

7275

76+
def array_from(tensor):
77+
from ._ndarray import ndarray
78+
79+
return ndarray._from_tensor_and_base(tensor, None) # XXX: nuke .base
80+
81+
82+
def tuple_arrays_from(result):
83+
from ._ndarray import asarray
84+
85+
return tuple(asarray(x) for x in result)
86+
87+
88+
# ### Various ways of converting array-likes to tensors ###
89+
90+
7391
def ndarrays_to_tensors(*inputs):
7492
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
7593
from ._ndarray import asarray, ndarray

torch_np/_ndarray.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from . import _binary_ufuncs, _dtypes, _helpers, _unary_ufuncs
6+
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
77
from ._decorators import (
88
NoValue,
99
axis_keepdims_wrapper,
@@ -374,20 +374,24 @@ def flatten(self, order="C"):
374374
result = self._tensor.flatten()
375375
return asarray(result)
376376

377-
def nonzero(self):
378-
tensor = self._tensor
379-
return tuple(asarray(_) for _ in tensor.nonzero(as_tuple=True))
377+
nonzero = _funcs.nonzero
378+
clip = _funcs.clip
379+
repeat = _funcs.repeat
380380

381-
def clip(self, min=None, max=None, out=None):
382-
tensor, t_min, t_max = _helpers.to_tensors_or_none(self, min, max)
383-
result = _impl.clip(tensor, t_min, t_max)
384-
return _helpers.result_or_out(result, out)
381+
diagonal = _funcs.diagonal
382+
trace = _funcs.trace
385383

386-
def repeat(self, repeats, axis=None):
387-
t_repeats = asarray(repeats).get() # XXX: scalar repeats
388-
tensor = self._tensor
389-
result = torch.repeat_interleave(tensor, t_repeats, axis)
390-
return asarray(result)
384+
### sorting ###
385+
386+
def sort(self, axis=-1, kind=None, order=None):
387+
# ndarray.sort works in-place
388+
result = _impl.sort(self._tensor, axis, kind, order)
389+
self._tensor = result
390+
391+
argsort = _funcs.argsort
392+
searchsorted = _funcs.searchsorted
393+
394+
### reductions ###
391395

392396
argmin = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmin))
393397
argmax = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmax))
@@ -411,16 +415,6 @@ def repeat(self, repeats, axis=None):
411415
axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum))
412416
)
413417

414-
def diagonal(self, offset=0, axis1=0, axis2=1):
415-
result = _impl.diagonal(self._tensor, offset, axis1, axis2)
416-
return asarray(result)
417-
418-
@dtype_to_torch
419-
def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
420-
tensor = self._tensor
421-
result = _impl.trace(tensor, offset, axis1, axis2, dtype)
422-
return _helpers.result_or_out(result, out)
423-
424418
### indexing ###
425419
@staticmethod
426420
def _upcast_int_indices(index):
@@ -442,22 +436,6 @@ def __setitem__(self, index, value):
442436
value = _helpers.ndarrays_to_tensors(value)
443437
return self._tensor.__setitem__(index, value)
444438

445-
### sorting ###
446-
447-
def sort(self, axis=-1, kind=None, order=None):
448-
# ndarray.sort works in-place
449-
result = _impl.sort(self._tensor, axis, kind, order)
450-
self._tensor = result
451-
452-
def argsort(self, axis=-1, kind=None, order=None):
453-
result = _impl.argsort(self._tensor, axis, kind, order)
454-
return asarray(result)
455-
456-
def searchsorted(self, v, side="left", sorter=None):
457-
v_t, sorter_t = _helpers.to_tensors_or_none(v, sorter)
458-
result = torch.searchsorted(self._tensor, v_t, side=side, sorter=sorter_t)
459-
return asarray(result)
460-
461439

462440
# This is the ideally the only place which talks to ndarray directly.
463441
# The rest goes through asarray (preferred) or array.

0 commit comments

Comments
 (0)