Skip to content

Commit 74523fa

Browse files
authored
Merge pull request #92 from Quansight-Labs/mangle_returns
MAINT: automagically wrap return values to ndarrays / lists / tuples
2 parents 906f005 + 91387f6 commit 74523fa

File tree

6 files changed

+130
-144
lines changed

6 files changed

+130
-144
lines changed

torch_np/_detail/implementations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,16 @@ def triu_indices_from(tensor, k):
102102
if tensor.ndim != 2:
103103
raise ValueError("input array must be 2-d")
104104
result = torch.triu_indices(tensor.shape[0], tensor.shape[1], offset=k)
105+
# unpack: numpy returns a 2-tuple of index arrays; torch returns a 2-row tensor
106+
result = tuple(result)
105107
return result
106108

107109

108110
def tril_indices_from(tensor, k=0):
109111
if tensor.ndim != 2:
110112
raise ValueError("input array must be 2-d")
111113
result = torch.tril_indices(tensor.shape[0], tensor.shape[1], offset=k)
114+
result = tuple(result)
112115
return result
113116

114117

torch_np/_funcs.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@
2020
@normalizer
2121
def nonzero(a: ArrayLike):
2222
result = a.nonzero(as_tuple=True)
23-
return _helpers.tuple_arrays_from(result)
23+
return result
2424

2525

2626
@normalizer
2727
def argwhere(a: ArrayLike):
2828
result = torch.argwhere(a)
29-
return _helpers.array_from(result)
29+
return result
3030

3131

3232
@normalizer
3333
def flatnonzero(a: ArrayLike):
3434
result = a.ravel().nonzero(as_tuple=True)[0]
35-
return _helpers.array_from(result)
35+
return result
3636

3737

3838
@normalizer
@@ -52,13 +52,13 @@ def clip(
5252
def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
5353
# XXX: scalar repeats; ArrayLikeOrScalar ?
5454
result = torch.repeat_interleave(a, repeats, axis)
55-
return _helpers.array_from(result)
55+
return result
5656

5757

5858
@normalizer
5959
def tile(A: ArrayLike, reps):
6060
result = _impl.tile(A, reps)
61-
return _helpers.array_from(result)
61+
return result
6262

6363

6464
# ### diag et al ###
@@ -67,7 +67,7 @@ def tile(A: ArrayLike, reps):
6767
@normalizer
6868
def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
6969
result = _impl.diagonal(a, offset, axis1, axis2)
70-
return _helpers.array_from(result)
70+
return result
7171

7272

7373
@normalizer
@@ -88,42 +88,42 @@ def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike
8888
if order != "C":
8989
raise NotImplementedError
9090
result = _impl.eye(N, M, k, dtype)
91-
return _helpers.array_from(result)
91+
return result
9292

9393

9494
@normalizer
9595
def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
9696
result = torch.eye(n, dtype=dtype)
97-
return _helpers.array_from(result)
97+
return result
9898

9999

100100
@normalizer
101101
def diag(v: ArrayLike, k=0):
102102
result = torch.diag(v, k)
103-
return _helpers.array_from(result)
103+
return result
104104

105105

106106
@normalizer
107107
def diagflat(v: ArrayLike, k=0):
108108
result = torch.diagflat(v, k)
109-
return _helpers.array_from(result)
109+
return result
110110

111111

112112
def diag_indices(n, ndim=2):
113113
result = _impl.diag_indices(n, ndim)
114-
return _helpers.tuple_arrays_from(result)
114+
return result
115115

116116

117117
@normalizer
118118
def diag_indices_from(arr: ArrayLike):
119119
result = _impl.diag_indices_from(arr)
120-
return _helpers.tuple_arrays_from(result)
120+
return result
121121

122122

123123
@normalizer
124124
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
125125
result = _impl.fill_diagonal(a, val, wrap)
126-
return _helpers.array_from(result)
126+
return result
127127

128128

129129
@normalizer
@@ -144,21 +144,21 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
144144
@normalizer
145145
def sort(a: ArrayLike, axis=-1, kind=None, order=None):
146146
result = _impl.sort(a, axis, kind, order)
147-
return _helpers.array_from(result)
147+
return result
148148

149149

150150
@normalizer
151151
def argsort(a: ArrayLike, axis=-1, kind=None, order=None):
152152
result = _impl.argsort(a, axis, kind, order)
153-
return _helpers.array_from(result)
153+
return result
154154

155155

156156
@normalizer
157157
def searchsorted(
158158
a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
159159
):
160160
result = torch.searchsorted(a, v, side=side, sorter=sorter)
161-
return _helpers.array_from(result)
161+
return result
162162

163163

164164
# ### swap/move/roll axis ###
@@ -167,19 +167,19 @@ def searchsorted(
167167
@normalizer
168168
def moveaxis(a: ArrayLike, source, destination):
169169
result = _impl.moveaxis(a, source, destination)
170-
return _helpers.array_from(result)
170+
return result
171171

172172

173173
@normalizer
174174
def swapaxes(a: ArrayLike, axis1, axis2):
175175
result = _impl.swapaxes(a, axis1, axis2)
176-
return _helpers.array_from(result)
176+
return result
177177

178178

179179
@normalizer
180180
def rollaxis(a: ArrayLike, axis, start=0):
181181
result = _impl.rollaxis(a, axis, start)
182-
return _helpers.array_from(result)
182+
return result
183183

184184

185185
# ### shape manipulations ###
@@ -188,32 +188,32 @@ def rollaxis(a: ArrayLike, axis, start=0):
188188
@normalizer
189189
def squeeze(a: ArrayLike, axis=None):
190190
result = _impl.squeeze(a, axis)
191-
return _helpers.array_from(result, a)
191+
return result
192192

193193

194194
@normalizer
195195
def reshape(a: ArrayLike, newshape, order="C"):
196196
result = _impl.reshape(a, newshape, order=order)
197-
return _helpers.array_from(result, a)
197+
return result
198198

199199

200200
@normalizer
201201
def transpose(a: ArrayLike, axes=None):
202202
result = _impl.transpose(a, axes)
203-
return _helpers.array_from(result, a)
203+
return result
204204

205205

206206
@normalizer
207207
def ravel(a: ArrayLike, order="C"):
208208
result = _impl.ravel(a)
209-
return _helpers.array_from(result, a)
209+
return result
210210

211211

212212
# leading underscore since arr.flatten exists but np.flatten does not
213213
@normalizer
214214
def _flatten(a: ArrayLike, order="C"):
215215
result = _impl._flatten(a)
216-
return _helpers.array_from(result, a)
216+
return result
217217

218218

219219
# ### Type/shape etc queries ###
@@ -222,13 +222,13 @@ def _flatten(a: ArrayLike, order="C"):
222222
@normalizer
223223
def real(a: ArrayLike):
224224
result = torch.real(a)
225-
return _helpers.array_from(result)
225+
return result
226226

227227

228228
@normalizer
229229
def imag(a: ArrayLike):
230230
result = _impl.imag(a)
231-
return _helpers.array_from(result)
231+
return result
232232

233233

234234
@normalizer
@@ -420,7 +420,7 @@ def any(
420420
@normalizer
421421
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
422422
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
423-
return _helpers.array_from(result)
423+
return result
424424

425425

426426
@normalizer

torch_np/_helpers.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,9 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
5555
out_tensor.copy_(result_tensor)
5656
return out_array
5757
else:
58-
return array_from(result_tensor)
58+
from ._ndarray import ndarray
5959

60-
61-
def array_from(tensor, base=None):
62-
from ._ndarray import ndarray
63-
64-
return ndarray(tensor)
65-
66-
67-
def tuple_arrays_from(result):
68-
from ._ndarray import asarray
69-
70-
return tuple(asarray(x) for x in result)
60+
return ndarray(result_tensor)
7161

7262

7363
# ### Various ways of converting array-likes to tensors ###
@@ -94,10 +84,3 @@ def ndarrays_to_tensors(*inputs):
9484
else:
9585
assert isinstance(inputs, tuple) # sanity check
9686
return ndarrays_to_tensors(inputs)
97-
98-
99-
def to_tensors(*inputs):
100-
"""Convert all array_likes from `inputs` to tensors."""
101-
from ._ndarray import asarray, ndarray
102-
103-
return tuple(asarray(value).tensor for value in inputs)

torch_np/_normalizations.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121

2222

2323
def normalize_array_like(x, name=None):
24-
(tensor,) = _helpers.to_tensors(x)
25-
return tensor
24+
from ._ndarray import asarray
25+
26+
return asarray(x).tensor
2627

2728

2829
def normalize_optional_array_like(x, name=None):
@@ -32,8 +33,7 @@ def normalize_optional_array_like(x, name=None):
3233

3334

3435
def normalize_seq_array_like(x, name=None):
35-
tensors = _helpers.to_tensors(*x)
36-
return tensors
36+
return tuple(normalize_array_like(value) for value in x)
3737

3838

3939
def normalize_dtype(dtype, name=None):
@@ -96,6 +96,28 @@ def maybe_normalize(arg, parm, return_on_failure=_sentinel):
9696
raise exc from None
9797

9898

99+
def wrap_tensors(result):
100+
from ._ndarray import ndarray
101+
102+
if isinstance(result, torch.Tensor):
103+
result = ndarray(result)
104+
elif isinstance(result, (tuple, list)):
105+
result = type(result)(
106+
ndarray(x) if isinstance(x, torch.Tensor) else x for x in result
107+
)
108+
109+
return result
110+
111+
112+
def array_or_scalar(values, py_type=float, return_scalar=False):
113+
if return_scalar:
114+
return py_type(values.item())
115+
else:
116+
from ._ndarray import ndarray
117+
118+
return ndarray(values)
119+
120+
99121
def normalizer(_func=None, *, return_on_failure=_sentinel):
100122
def normalizer_inner(func):
101123
@functools.wraps(func)
@@ -121,7 +143,9 @@ def wrapped(*args, **kwds):
121143
name: maybe_normalize(arg, params[name]) if name in params else arg
122144
for name, arg in kwds.items()
123145
}
124-
return func(*args, **kwds)
146+
result = func(*args, **kwds)
147+
result = wrap_tensors(result)
148+
return result
125149

126150
return wrapped
127151

0 commit comments

Comments
 (0)