Skip to content

Commit a126f71

Browse files
committed
MAINT: remove return annotation, rework out= arg handling
1 parent 3a3b100 commit a126f71

File tree

6 files changed

+84
-114
lines changed

6 files changed

+84
-114
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,7 @@
44

55
from . import _helpers
66
from ._detail import _binary_ufuncs
7-
from ._normalizations import (
8-
ArrayLike,
9-
DTypeLike,
10-
NDArray,
11-
OutArray,
12-
SubokLike,
13-
normalizer,
14-
)
7+
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
158

169
__all__ = [
1710
name
@@ -40,7 +33,7 @@ def wrapped(
4033
subok: SubokLike = False,
4134
signature=None,
4235
extobj=None,
43-
) -> OutArray:
36+
):
4437
tensors = _helpers.ufunc_preprocess(
4538
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
4639
)
@@ -51,7 +44,7 @@ def wrapped(
5144
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
5245

5346
result = torch_func(*tensors)
54-
return result, out
47+
return result
5548

5649
return wrapped
5750

@@ -75,7 +68,7 @@ def matmul(
7568
extobj=None,
7669
axes=None,
7770
axis=None,
78-
) -> OutArray:
71+
):
7972
tensors = _helpers.ufunc_preprocess(
8073
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
8174
)
@@ -84,7 +77,7 @@ def matmul(
8477

8578
# NB: do not broadcast input tensors against the out=... array
8679
result = _binary_ufuncs.matmul(*tensors)
87-
return result, out
80+
return result
8881

8982

9083
#

torch_np/_funcs.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
AxisLike,
1111
DTypeLike,
1212
NDArray,
13-
OutArray,
1413
SubokLike,
1514
normalizer,
1615
)
@@ -42,11 +41,11 @@ def clip(
4241
min: Optional[ArrayLike] = None,
4342
max: Optional[ArrayLike] = None,
4443
out: Optional[NDArray] = None,
45-
) -> OutArray:
44+
):
4645
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
4746
# one of them to be None. Follow the more lax version.
4847
result = _impl.clip(a, min, max)
49-
return result, out
48+
return result
5049

5150

5251
@normalizer
@@ -79,9 +78,9 @@ def trace(
7978
axis2=1,
8079
dtype: DTypeLike = None,
8180
out: Optional[NDArray] = None,
82-
) -> OutArray:
81+
):
8382
result = _impl.trace(a, offset, axis1, axis2, dtype)
84-
return result, out
83+
return result
8584

8685

8786
@normalizer
@@ -134,9 +133,9 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
134133

135134

136135
@normalizer
137-
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None) -> OutArray:
136+
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
138137
result = _impl.dot(a, b)
139-
return result, out
138+
return result
140139

141140

142141
# ### sort and partition ###
@@ -233,9 +232,9 @@ def imag(a: ArrayLike):
233232

234233

235234
@normalizer
236-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
235+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
237236
result = _impl.round(a, decimals)
238-
return result, out
237+
return result
239238

240239

241240
around = round_
@@ -254,11 +253,11 @@ def sum(
254253
keepdims=NoValue,
255254
initial=NoValue,
256255
where=NoValue,
257-
) -> OutArray:
256+
):
258257
result = _impl.sum(
259258
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
260259
)
261-
return result, out
260+
return result
262261

263262

264263
@normalizer
@@ -270,11 +269,11 @@ def prod(
270269
keepdims=NoValue,
271270
initial=NoValue,
272271
where=NoValue,
273-
) -> OutArray:
272+
):
274273
result = _impl.prod(
275274
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
276275
)
277-
return result, out
276+
return result
278277

279278

280279
product = prod
@@ -289,9 +288,9 @@ def mean(
289288
keepdims=NoValue,
290289
*,
291290
where=NoValue,
292-
) -> OutArray:
291+
):
293292
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
294-
return result, out
293+
return result
295294

296295

297296
@normalizer
@@ -304,11 +303,11 @@ def var(
304303
keepdims=NoValue,
305304
*,
306305
where=NoValue,
307-
) -> OutArray:
306+
):
308307
result = _impl.var(
309308
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
310309
)
311-
return result, out
310+
return result
312311

313312

314313
@normalizer
@@ -321,11 +320,11 @@ def std(
321320
keepdims=NoValue,
322321
*,
323322
where=NoValue,
324-
) -> OutArray:
323+
):
325324
result = _impl.std(
326325
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
327326
)
328-
return result, out
327+
return result
329328

330329

331330
@normalizer
@@ -335,9 +334,9 @@ def argmin(
335334
out: Optional[NDArray] = None,
336335
*,
337336
keepdims=NoValue,
338-
) -> OutArray:
337+
):
339338
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
340-
return result, out
339+
return result
341340

342341

343342
@normalizer
@@ -347,9 +346,9 @@ def argmax(
347346
out: Optional[NDArray] = None,
348347
*,
349348
keepdims=NoValue,
350-
) -> OutArray:
349+
):
351350
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
352-
return result, out
351+
return result
353352

354353

355354
@normalizer
@@ -360,9 +359,9 @@ def amax(
360359
keepdims=NoValue,
361360
initial=NoValue,
362361
where=NoValue,
363-
) -> OutArray:
362+
):
364363
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
365-
return result, out
364+
return result
366365

367366

368367
max = amax
@@ -376,9 +375,9 @@ def amin(
376375
keepdims=NoValue,
377376
initial=NoValue,
378377
where=NoValue,
379-
) -> OutArray:
378+
):
380379
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
381-
return result, out
380+
return result
382381

383382

384383
min = amin
@@ -387,9 +386,9 @@ def amin(
387386
@normalizer
388387
def ptp(
389388
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
390-
) -> OutArray:
389+
):
391390
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
392-
return result, out
391+
return result
393392

394393

395394
@normalizer
@@ -400,9 +399,9 @@ def all(
400399
keepdims=NoValue,
401400
*,
402401
where=NoValue,
403-
) -> OutArray:
402+
):
404403
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
405-
return result, out
404+
return result
406405

407406

408407
@normalizer
@@ -413,9 +412,9 @@ def any(
413412
keepdims=NoValue,
414413
*,
415414
where=NoValue,
416-
) -> OutArray:
415+
):
417416
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
418-
return result, out
417+
return result
419418

420419

421420
@normalizer
@@ -430,9 +429,9 @@ def cumsum(
430429
axis: AxisLike = None,
431430
dtype: DTypeLike = None,
432431
out: Optional[NDArray] = None,
433-
) -> OutArray:
432+
):
434433
result = _impl.cumsum(a, axis=axis, dtype=dtype)
435-
return result, out
434+
return result
436435

437436

438437
@normalizer
@@ -441,9 +440,9 @@ def cumprod(
441440
axis: AxisLike = None,
442441
dtype: DTypeLike = None,
443442
out: Optional[NDArray] = None,
444-
) -> OutArray:
443+
):
445444
result = _impl.cumprod(a, axis=axis, dtype=dtype)
446-
return result, out
445+
return result
447446

448447

449448
cumproduct = cumprod
@@ -460,7 +459,7 @@ def quantile(
460459
keepdims=False,
461460
*,
462461
interpolation=None,
463-
) -> OutArray:
462+
):
464463
result = _impl.quantile(
465464
a,
466465
q,
@@ -470,7 +469,7 @@ def quantile(
470469
keepdims=keepdims,
471470
interpolation=interpolation,
472471
)
473-
return result, out
472+
return result
474473

475474

476475
@normalizer(promote_scalar_result=True)
@@ -484,7 +483,7 @@ def percentile(
484483
keepdims=False,
485484
*,
486485
interpolation=None,
487-
) -> OutArray:
486+
):
488487
result = _impl.percentile(
489488
a,
490489
q,
@@ -494,7 +493,7 @@ def percentile(
494493
keepdims=keepdims,
495494
interpolation=interpolation,
496495
)
497-
return result, out
496+
return result
498497

499498

500499
def median(

0 commit comments

Comments
 (0)