Skip to content

Commit f667f8c

Browse files
authored
Merge pull request #69 from Quansight-Labs/free_funcs
move ndarray methods to free functions
2 parents 9df7b54 + bb75c17 commit f667f8c

File tree

8 files changed

+290
-256
lines changed

8 files changed

+290
-256
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._detail._index_tricks import *
55
from ._detail._util import AxisError, UFuncTypeError
66
from ._dtypes import *
7+
from ._funcs import *
78
from ._getlimits import errstate, finfo, iinfo
89
from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
910
from ._unary_ufuncs import *

torch_np/_detail/implementations.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22

3-
from .. import _helpers
43
from . import _dtypes_impl, _util
54

65
# ### equality, equivalence, allclose ###
@@ -581,21 +580,18 @@ def squeeze(tensor, axis=None):
581580
return result
582581

583582

584-
def reshape(tensor, *shape, order="C"):
583+
def reshape(tensor, shape, order="C"):
585584
if order != "C":
586585
raise NotImplementedError
587-
newshape = shape[0] if len(shape) == 1 else shape
588-
# convert any tnp.ndarray inputs into tensors before passing to torch.Tensor.reshape
589-
t_newshape = _helpers.ndarrays_to_tensors(newshape)
590586
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
591-
result = tensor.reshape(t_newshape)
587+
newshape = shape[0] if len(shape) == 1 else shape
588+
result = tensor.reshape(newshape)
592589
return result
593590

594591

595-
def transpose(tensor, *axes):
596-
# numpy allows both .reshape(sh) and .reshape(*sh)
597-
axes = axes[0] if len(axes) == 1 else axes
598-
if axes == () or axes is None:
592+
def transpose(tensor, axes=None):
593+
# numpy allows both .tranpose(sh) and .transpose(*sh)
594+
if axes in [(), None, (None,)]:
599595
axes = tuple(range(tensor.ndim))[::-1]
600596
try:
601597
result = tensor.permute(axes)
@@ -604,6 +600,32 @@ def transpose(tensor, *axes):
604600
return result
605601

606602

603+
def ravel(tensor, order="C"):
604+
if order != "C":
605+
raise NotImplementedError
606+
result = tensor.ravel()
607+
return result
608+
609+
610+
# leading underscore since arr.flatten exists but np.flatten does not
611+
def _flatten(tensor, order="C"):
612+
if order != "C":
613+
raise NotImplementedError
614+
# return a copy
615+
result = tensor.flatten()
616+
return result
617+
618+
619+
# ### swap/move/roll axis ###
620+
621+
622+
def moveaxis(tensor, source, destination):
623+
source = _util.normalize_axis_tuple(source, tensor.ndim, "source")
624+
destination = _util.normalize_axis_tuple(destination, tensor.ndim, "destination")
625+
result = torch.moveaxis(tensor, source, destination)
626+
return result
627+
628+
607629
# ### Numeric ###
608630

609631

@@ -622,6 +644,14 @@ def round(tensor, decimals=0):
622644
return result
623645

624646

647+
def imag(tensor):
648+
if tensor.is_complex():
649+
result = tensor.imag
650+
else:
651+
result = torch.zeros_like(tensor)
652+
return result
653+
654+
625655
# ### put/take along axis ###
626656

627657

torch_np/_funcs.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import torch
2+
3+
from . import _decorators, _helpers
4+
from ._detail import _flips, _util
5+
from ._detail import implementations as _impl
6+
7+
8+
def nonzero(a):
9+
(tensor,) = _helpers.to_tensors(a)
10+
result = tensor.nonzero(as_tuple=True)
11+
return _helpers.tuple_arrays_from(result)
12+
13+
14+
def argwhere(a):
15+
(tensor,) = _helpers.to_tensors(a)
16+
result = torch.argwhere(tensor)
17+
return _helpers.array_from(result)
18+
19+
20+
def clip(a, min=None, max=None, out=None):
21+
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
22+
# one of them to be None. Follow the more lax version.
23+
# Also min/max as arg names: follow numpy naming.
24+
tensor, t_min, t_max = _helpers.to_tensors_or_none(a, min, max)
25+
result = _impl.clip(tensor, t_min, t_max)
26+
return _helpers.result_or_out(result, out)
27+
28+
29+
def repeat(a, repeats, axis=None):
30+
tensor, t_repeats = _helpers.to_tensors(a, repeats) # XXX: scalar repeats
31+
result = torch.repeat_interleave(tensor, t_repeats, axis)
32+
return _helpers.array_from(result)
33+
34+
35+
# ### diag et al ###
36+
37+
38+
def diagonal(a, offset=0, axis1=0, axis2=1):
39+
(tensor,) = _helpers.to_tensors(a)
40+
result = _impl.diagonal(tensor, offset, axis1, axis2)
41+
return _helpers.array_from(result)
42+
43+
44+
@_decorators.dtype_to_torch
45+
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
46+
(tensor,) = _helpers.to_tensors(a)
47+
result = _impl.trace(tensor, offset, axis1, axis2, dtype)
48+
return _helpers.result_or_out(result, out)
49+
50+
51+
@_decorators.dtype_to_torch
52+
def eye(N, M=None, k=0, dtype=float, order="C", *, like=None):
53+
_util.subok_not_ok(like)
54+
if order != "C":
55+
raise NotImplementedError
56+
result = _impl.eye(N, M, k, dtype)
57+
return _helpers.array_from(result)
58+
59+
60+
@_decorators.dtype_to_torch
61+
def identity(n, dtype=None, *, like=None):
62+
_util.subok_not_ok(like)
63+
result = torch.eye(n, dtype=dtype)
64+
return _helpers.array_from(result)
65+
66+
67+
def diag(v, k=0):
68+
(tensor,) = _helpers.to_tensors(v)
69+
result = torch.diag(tensor, k)
70+
return _helpers.array_from(result)
71+
72+
73+
def diagflat(v, k=0):
74+
(tensor,) = _helpers.to_tensors(v)
75+
result = torch.diagflat(tensor, k)
76+
return _helpers.array_from(result)
77+
78+
79+
def diag_indices(n, ndim=2):
80+
result = _impl.diag_indices(n, ndim)
81+
return _helpers.tuple_arrays_from(result)
82+
83+
84+
def diag_indices_from(arr):
85+
(tensor,) = _helpers.to_tensors(arr)
86+
result = _impl.diag_indices_from(tensor)
87+
return _helpers.tuple_arrays_from(result)
88+
89+
90+
def fill_diagonal(a, val, wrap=False):
91+
tensor, t_val = _helpers.to_tensors(a, val)
92+
result = _impl.fill_diagonal(tensor, t_val, wrap)
93+
return _helpers.array_from(result)
94+
95+
96+
# ### sorting ###
97+
98+
# ### sort and partition ###
99+
100+
101+
def sort(a, axis=-1, kind=None, order=None):
102+
(tensor,) = _helpers.to_tensors(a)
103+
result = _impl.sort(tensor, axis, kind, order)
104+
return _helpers.array_from(result)
105+
106+
107+
def argsort(a, axis=-1, kind=None, order=None):
108+
(tensor,) = _helpers.to_tensors(a)
109+
result = _impl.argsort(tensor, axis, kind, order)
110+
return _helpers.array_from(result)
111+
112+
113+
def searchsorted(a, v, side="left", sorter=None):
114+
a_t, v_t, sorter_t = _helpers.to_tensors_or_none(a, v, sorter)
115+
result = torch.searchsorted(a_t, v_t, side=side, sorter=sorter_t)
116+
return _helpers.array_from(result)
117+
118+
119+
# ### swap/move/roll axis ###
120+
121+
122+
def moveaxis(a, source, destination):
123+
(tensor,) = _helpers.to_tensors(a)
124+
result = _impl.moveaxis(tensor, source, destination)
125+
return _helpers.array_from(result)
126+
127+
128+
def swapaxes(a, axis1, axis2):
129+
(tensor,) = _helpers.to_tensors(a)
130+
result = _flips.swapaxes(tensor, axis1, axis2)
131+
return _helpers.array_from(result)
132+
133+
134+
def rollaxis(a, axis, start=0):
135+
(tensor,) = _helpers.to_tensors(a)
136+
result = _flips.rollaxis(a, axis, start)
137+
return _helpers.array_from(result)
138+
139+
140+
# ### shape manipulations ###
141+
142+
143+
def squeeze(a, axis=None):
144+
(tensor,) = _helpers.to_tensors(a)
145+
result = _impl.squeeze(tensor, axis)
146+
return _helpers.array_from(result, a)
147+
148+
149+
def reshape(a, newshape, order="C"):
150+
(tensor,) = _helpers.to_tensors(a)
151+
result = _impl.reshape(tensor, newshape, order=order)
152+
return _helpers.array_from(result, a)
153+
154+
155+
def transpose(a, axes=None):
156+
(tensor,) = _helpers.to_tensors(a)
157+
result = _impl.transpose(tensor, axes)
158+
return _helpers.array_from(result, a)
159+
160+
161+
def ravel(a, order="C"):
162+
(tensor,) = _helpers.to_tensors(a)
163+
result = _impl.ravel(tensor)
164+
return _helpers.array_from(result, a)
165+
166+
167+
# leading underscore since arr.flatten exists but np.flatten does not
168+
def _flatten(a, order="C"):
169+
(tensor,) = _helpers.to_tensors(a)
170+
result = _impl._flatten(tensor)
171+
return _helpers.array_from(result, a)
172+
173+
174+
# ### Type/shape etc queries ###
175+
176+
177+
def real(a):
178+
(tensor,) = _helpers.to_tensors(a)
179+
result = torch.real(tensor)
180+
return _helpers.array_from(result)
181+
182+
183+
def imag(a):
184+
(tensor,) = _helpers.to_tensors(a)
185+
result = _impl.imag(tensor)
186+
return _helpers.array_from(result)
187+
188+
189+
def round_(a, decimals=0, out=None):
190+
(tensor,) = _helpers.to_tensors(a)
191+
result = _impl.round(tensor, decimals)
192+
return _helpers.result_or_out(result, out)
193+
194+
195+
around = round_
196+
round = round_

torch_np/_helpers.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def cast_and_broadcast(tensors, out, casting):
4141
return tuple(tensors)
4242

4343

44+
# ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ###
45+
46+
4447
def result_or_out(result_tensor, out_array=None, promote_scalar=False):
4548
"""A helper for returns with out= argument.
4649
@@ -70,6 +73,22 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
7073
return asarray(result_tensor)
7174

7275

76+
def array_from(tensor, base=None):
77+
from ._ndarray import ndarray
78+
79+
base = base if isinstance(base, ndarray) else None
80+
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
81+
82+
83+
def tuple_arrays_from(result):
84+
from ._ndarray import asarray
85+
86+
return tuple(asarray(x) for x in result)
87+
88+
89+
# ### Various ways of converting array-likes to tensors ###
90+
91+
7392
def ndarrays_to_tensors(*inputs):
7493
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
7594
from ._ndarray import asarray, ndarray
@@ -105,11 +124,3 @@ def to_tensors_or_none(*inputs):
105124
from ._ndarray import asarray, ndarray
106125

107126
return tuple(None if value is None else asarray(value).get() for value in inputs)
108-
109-
110-
def _outer(x, y):
111-
from ._ndarray import asarray
112-
113-
x_tensor, y_tensor = to_tensors(x, y)
114-
result = torch.outer(x_tensor, y_tensor)
115-
return asarray(result)

0 commit comments

Comments
 (0)