Skip to content

Commit fb65599

Browse files
committed
MAINT: add the return annotation for return asarray(result)
1 parent 0210366 commit fb65599

File tree

4 files changed

+151
-144
lines changed

4 files changed

+151
-144
lines changed

torch_np/_funcs.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@ def nonzero(a: ArrayLike):
2525

2626

2727
@normalizer
28-
def argwhere(a: ArrayLike):
28+
def argwhere(a: ArrayLike) -> NDArray:
2929
result = torch.argwhere(a)
30-
return _helpers.array_from(result)
31-
30+
return result
3231

3332
@normalizer
34-
def flatnonzero(a: ArrayLike):
33+
def flatnonzero(a: ArrayLike) -> NDArray:
3534
result = a.ravel().nonzero(as_tuple=True)[0]
36-
return _helpers.array_from(result)
35+
return result
3736

3837

3938
@normalizer
@@ -50,25 +49,24 @@ def clip(
5049

5150

5251
@normalizer
53-
def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
54-
# XXX: scalar repeats; ArrayLikeOrScalar ?
52+
def repeat(a: ArrayLike, repeats: ArrayLike, axis=None) -> NDArray:
5553
result = torch.repeat_interleave(a, repeats, axis)
56-
return _helpers.array_from(result)
54+
return result
5755

5856

5957
@normalizer
60-
def tile(A: ArrayLike, reps):
58+
def tile(A: ArrayLike, reps) -> NDArray:
6159
result = _impl.tile(A, reps)
62-
return _helpers.array_from(result)
60+
return result
6361

6462

6563
# ### diag et al ###
6664

6765

6866
@normalizer
69-
def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
67+
def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1) -> NDArray:
7068
result = _impl.diagonal(a, offset, axis1, axis2)
71-
return _helpers.array_from(result)
69+
return result
7270

7371

7472
@normalizer
@@ -85,29 +83,29 @@ def trace(
8583

8684

8785
@normalizer
88-
def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None):
86+
def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None) -> NDArray:
8987
if order != "C":
9088
raise NotImplementedError
9189
result = _impl.eye(N, M, k, dtype)
92-
return _helpers.array_from(result)
90+
return result
9391

9492

9593
@normalizer
96-
def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
94+
def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None) -> NDArray:
9795
result = torch.eye(n, dtype=dtype)
98-
return _helpers.array_from(result)
96+
return result
9997

10098

10199
@normalizer
102-
def diag(v: ArrayLike, k=0):
100+
def diag(v: ArrayLike, k=0) -> NDArray:
103101
result = torch.diag(v, k)
104-
return _helpers.array_from(result)
102+
return result
105103

106104

107105
@normalizer
108-
def diagflat(v: ArrayLike, k=0):
106+
def diagflat(v: ArrayLike, k=0) -> NDArray:
109107
result = torch.diagflat(v, k)
110-
return _helpers.array_from(result)
108+
return result
111109

112110

113111
def diag_indices(n, ndim=2):
@@ -122,9 +120,9 @@ def diag_indices_from(arr: ArrayLike):
122120

123121

124122
@normalizer
125-
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
123+
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False) -> NDArray:
126124
result = _impl.fill_diagonal(a, val, wrap)
127-
return _helpers.array_from(result)
125+
return result
128126

129127

130128
@normalizer
@@ -143,93 +141,93 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
143141

144142

145143
@normalizer
146-
def sort(a: ArrayLike, axis=-1, kind=None, order=None):
144+
def sort(a: ArrayLike, axis=-1, kind=None, order=None) -> NDArray:
147145
result = _impl.sort(a, axis, kind, order)
148-
return _helpers.array_from(result)
146+
return result
149147

150148

151149
@normalizer
152-
def argsort(a: ArrayLike, axis=-1, kind=None, order=None):
150+
def argsort(a: ArrayLike, axis=-1, kind=None, order=None) -> NDArray:
153151
result = _impl.argsort(a, axis, kind, order)
154-
return _helpers.array_from(result)
152+
return result
155153

156154

157155
@normalizer
158156
def searchsorted(
159157
a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
160-
):
158+
) -> NDArray:
161159
result = torch.searchsorted(a, v, side=side, sorter=sorter)
162-
return _helpers.array_from(result)
160+
return result
163161

164162

165163
# ### swap/move/roll axis ###
166164

167165

168166
@normalizer
169-
def moveaxis(a: ArrayLike, source, destination):
167+
def moveaxis(a: ArrayLike, source, destination) -> NDArray:
170168
result = _impl.moveaxis(a, source, destination)
171-
return _helpers.array_from(result)
169+
return result
172170

173171

174172
@normalizer
175-
def swapaxes(a: ArrayLike, axis1, axis2):
173+
def swapaxes(a: ArrayLike, axis1, axis2) -> NDArray:
176174
result = _impl.swapaxes(a, axis1, axis2)
177-
return _helpers.array_from(result)
175+
return result
178176

179177

180178
@normalizer
181-
def rollaxis(a: ArrayLike, axis, start=0):
179+
def rollaxis(a: ArrayLike, axis, start=0) -> NDArray:
182180
result = _impl.rollaxis(a, axis, start)
183-
return _helpers.array_from(result)
181+
return result
184182

185183

186184
# ### shape manipulations ###
187185

188186

189187
@normalizer
190-
def squeeze(a: ArrayLike, axis=None):
188+
def squeeze(a: ArrayLike, axis=None) -> NDArray:
191189
result = _impl.squeeze(a, axis)
192-
return _helpers.array_from(result, a)
190+
return result
193191

194192

195193
@normalizer
196-
def reshape(a: ArrayLike, newshape, order="C"):
194+
def reshape(a: ArrayLike, newshape, order="C") -> NDArray:
197195
result = _impl.reshape(a, newshape, order=order)
198-
return _helpers.array_from(result, a)
196+
return result
199197

200198

201199
@normalizer
202-
def transpose(a: ArrayLike, axes=None):
200+
def transpose(a: ArrayLike, axes=None) -> NDArray:
203201
result = _impl.transpose(a, axes)
204-
return _helpers.array_from(result, a)
202+
return result
205203

206204

207205
@normalizer
208-
def ravel(a: ArrayLike, order="C"):
206+
def ravel(a: ArrayLike, order="C") -> NDArray:
209207
result = _impl.ravel(a)
210-
return _helpers.array_from(result, a)
208+
return result
211209

212210

213211
# leading underscore since arr.flatten exists but np.flatten does not
214212
@normalizer
215-
def _flatten(a: ArrayLike, order="C"):
213+
def _flatten(a: ArrayLike, order="C") -> NDArray:
216214
result = _impl._flatten(a)
217-
return _helpers.array_from(result, a)
215+
return result
218216

219217

220218
# ### Type/shape etc queries ###
221219

222220

223221
@normalizer
224-
def real(a: ArrayLike):
222+
def real(a: ArrayLike) -> NDArray:
225223
result = torch.real(a)
226-
return _helpers.array_from(result)
224+
return result
227225

228226

229227
@normalizer
230-
def imag(a: ArrayLike):
228+
def imag(a: ArrayLike) -> NDArray:
231229
result = _impl.imag(a)
232-
return _helpers.array_from(result)
230+
return result
233231

234232

235233
@normalizer
@@ -419,9 +417,9 @@ def any(
419417

420418

421419
@normalizer
422-
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
420+
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False) -> NDArray:
423421
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
424-
return _helpers.array_from(result)
422+
return result
425423

426424

427425
@normalizer

torch_np/_normalizations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,15 @@ def wrapped(*args, **kwds):
152152
)
153153
# finally, pass normalized arguments through
154154
result = func(*ba.args, **ba.kwargs)
155-
return result
155+
156+
# handle returns
157+
return_annotation = sig.return_annotation
158+
if return_annotation == NDArray:
159+
return _helpers.array_from(result)
160+
elif return_annotation == inspect._empty:
161+
return result
162+
else:
163+
raise ValueError(f"Unknown return annotation {return_annotation}")
156164

157165
return wrapped
158166

0 commit comments

Comments
 (0)