Skip to content

Commit 16c5aed

Browse files
committed
MAINT: annotate out as NDArray, remove scattered isinstance checks
1 parent 10672bb commit 16c5aed

File tree

6 files changed

+53
-45
lines changed

6 files changed

+53
-45
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from . import _helpers
22
from ._detail import _binary_ufuncs
3-
from ._normalizations import ArrayLike, DTypeLike, SubokLike, normalizer
3+
from ._normalizations import ArrayLike, DTypeLike, SubokLike, NDArray, normalizer
4+
from typing import Optional
5+
46

57
__all__ = [
68
name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
@@ -18,7 +20,7 @@ def wrapped(
1820
x1: ArrayLike,
1921
x2: ArrayLike,
2022
/,
21-
out=None,
23+
out: Optional[NDArray] = None,
2224
*,
2325
where=True,
2426
casting="same_kind",

torch_np/_funcs.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
ArrayLike,
1010
AxisLike,
1111
DTypeLike,
12+
NDArray,
1213
SubokLike,
1314
UnpackedSeqArrayLike,
1415
normalizer,
1516
)
17+
from typing import Optional
1618

1719

1820
@normalizer
@@ -32,7 +34,7 @@ def clip(
3234
a: ArrayLike,
3335
min: Optional[ArrayLike] = None,
3436
max: Optional[ArrayLike] = None,
35-
out=None,
37+
out: Optional[NDArray] = None,
3638
):
3739
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
3840
# one of them to be None. Follow the more lax version.
@@ -57,7 +59,7 @@ def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
5759

5860

5961
@normalizer
60-
def trace(a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: DTypeLike = None, out=None):
62+
def trace(a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: DTypeLike = None, out: Optional[NDArray] = None):
6163
result = _impl.trace(a, offset, axis1, axis2, dtype)
6264
return _helpers.result_or_out(result, out)
6365

@@ -112,7 +114,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
112114

113115

114116
@normalizer
115-
def dot(a: ArrayLike, b: ArrayLike, out=None):
117+
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
116118
result = _impl.dot(a, b)
117119
return _helpers.result_or_out(result, out)
118120

@@ -211,7 +213,7 @@ def imag(a: ArrayLike):
211213

212214

213215
@normalizer
214-
def round_(a: ArrayLike, decimals=0, out=None):
216+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray]=None):
215217
result = _impl.round(a, decimals)
216218
return _helpers.result_or_out(result, out)
217219

@@ -231,7 +233,7 @@ def sum(
231233
a: ArrayLike,
232234
axis: AxisLike = None,
233235
dtype: DTypeLike = None,
234-
out=None,
236+
out: Optional[NDArray]=None,
235237
keepdims=NoValue,
236238
initial=NoValue,
237239
where=NoValue,
@@ -247,7 +249,7 @@ def prod(
247249
a: ArrayLike,
248250
axis: AxisLike = None,
249251
dtype: DTypeLike = None,
250-
out=None,
252+
out: Optional[NDArray]=None,
251253
keepdims=NoValue,
252254
initial=NoValue,
253255
where=NoValue,
@@ -266,7 +268,7 @@ def mean(
266268
a: ArrayLike,
267269
axis: AxisLike = None,
268270
dtype: DTypeLike = None,
269-
out=None,
271+
out: Optional[NDArray]=None,
270272
keepdims=NoValue,
271273
*,
272274
where=NoValue,
@@ -282,7 +284,7 @@ def var(
282284
a: ArrayLike,
283285
axis: AxisLike = None,
284286
dtype: DTypeLike = None,
285-
out=None,
287+
out: Optional[NDArray]=None,
286288
ddof=0,
287289
keepdims=NoValue,
288290
*,
@@ -299,7 +301,7 @@ def std(
299301
a: ArrayLike,
300302
axis: AxisLike = None,
301303
dtype: DTypeLike = None,
302-
out=None,
304+
out: Optional[NDArray]=None,
303305
ddof=0,
304306
keepdims=NoValue,
305307
*,
@@ -312,13 +314,13 @@ def std(
312314

313315

314316
@normalizer
315-
def argmin(a: ArrayLike, axis: AxisLike = None, out=None, *, keepdims=NoValue):
317+
def argmin(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, *, keepdims=NoValue):
316318
result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
317319
return _helpers.result_or_out(result, out)
318320

319321

320322
@normalizer
321-
def argmax(a: ArrayLike, axis: AxisLike = None, out=None, *, keepdims=NoValue):
323+
def argmax(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, *, keepdims=NoValue):
322324
result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
323325
return _helpers.result_or_out(result, out)
324326

@@ -327,7 +329,7 @@ def argmax(a: ArrayLike, axis: AxisLike = None, out=None, *, keepdims=NoValue):
327329
def amax(
328330
a: ArrayLike,
329331
axis: AxisLike = None,
330-
out=None,
332+
out: Optional[NDArray] = None,
331333
keepdims=NoValue,
332334
initial=NoValue,
333335
where=NoValue,
@@ -345,7 +347,7 @@ def amax(
345347
def amin(
346348
a: ArrayLike,
347349
axis: AxisLike = None,
348-
out=None,
350+
out: Optional[NDArray] = None,
349351
keepdims=NoValue,
350352
initial=NoValue,
351353
where=NoValue,
@@ -360,22 +362,22 @@ def amin(
360362

361363

362364
@normalizer
363-
def ptp(a: ArrayLike, axis: AxisLike = None, out=None, keepdims=NoValue):
365+
def ptp(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue):
364366
result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
365367
return _helpers.result_or_out(result, out)
366368

367369

368370
@normalizer
369371
def all(
370-
a: ArrayLike, axis: AxisLike = None, out=None, keepdims=NoValue, *, where=NoValue
372+
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue, *, where=NoValue
371373
):
372374
result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
373375
return _helpers.result_or_out(result, out)
374376

375377

376378
@normalizer
377379
def any(
378-
a: ArrayLike, axis: AxisLike = None, out=None, keepdims=NoValue, *, where=NoValue
380+
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue, *, where=NoValue
379381
):
380382
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
381383
return _helpers.result_or_out(result, out)
@@ -388,13 +390,13 @@ def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
388390

389391

390392
@normalizer
391-
def cumsum(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=None):
393+
def cumsum(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out: Optional[NDArray] = None):
392394
result = _reductions.cumsum(a, axis=axis, dtype=dtype)
393395
return _helpers.result_or_out(result, out)
394396

395397

396398
@normalizer
397-
def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=None):
399+
def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out: Optional[NDArray] = None):
398400
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
399401
return _helpers.result_or_out(result, out)
400402

@@ -407,7 +409,7 @@ def quantile(
407409
a: ArrayLike,
408410
q: ArrayLike,
409411
axis: AxisLike = None,
410-
out=None,
412+
out: Optional[NDArray] = None,
411413
overwrite_input=False,
412414
method="linear",
413415
keepdims=False,

torch_np/_helpers.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ def cast_and_broadcast(tensors, out, casting):
2929
if out is None:
3030
return tensors
3131
else:
32-
from ._ndarray import asarray, ndarray
33-
34-
if not isinstance(out, ndarray):
35-
raise TypeError("Return arrays must be of ArrayType")
36-
3732
tensors = _util.cast_and_broadcast(
3833
tensors, out.dtype.type.torch_dtype, out.shape, casting
3934
)
@@ -72,11 +67,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
7267
result_tensor is placed into the out array.
7368
This weirdness is used e.g. in `np.percentile`
7469
"""
75-
from ._ndarray import asarray, ndarray
76-
7770
if out_array is not None:
78-
if not isinstance(out_array, ndarray):
79-
raise TypeError("Return arrays must be of ArrayType")
8071
if result_tensor.shape != out_array.shape:
8172
can_fit = result_tensor.numel() == 1 and out_array.ndim == 0
8273
if promote_scalar and can_fit:
@@ -90,7 +81,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
9081
out_tensor.copy_(result_tensor)
9182
return out_array
9283
else:
93-
return asarray(result_tensor)
84+
return array_from(result_tensor)
9485

9586

9687
def array_from(tensor, base=None):

torch_np/_normalizations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DTypeLike = typing.TypeVar("DTypeLike")
1313
SubokLike = typing.TypeVar("SubokLike")
1414
AxisLike = typing.TypeVar("AxisLike")
15+
NDArray = typing.TypeVar("NDarray")
1516

1617
# annotate e.g. atleast_1d(*arys)
1718
UnpackedSeqArrayLike = typing.TypeVar("UnpackedSeqArrayLike")
@@ -60,11 +61,24 @@ def normalize_axis_like(arg, name=None):
6061
return arg
6162

6263

64+
def normalize_ndarray(arg, name=None):
65+
if arg is None:
66+
return arg
67+
68+
from ._ndarray import ndarray
69+
70+
if not isinstance(arg, ndarray):
71+
raise TypeError("'out' must be an array")
72+
return arg
73+
74+
75+
6376
normalizers = {
6477
ArrayLike: normalize_array_like,
6578
Optional[ArrayLike]: normalize_optional_array_like,
6679
Sequence[ArrayLike]: normalize_seq_array_like,
6780
UnpackedSeqArrayLike: normalize_seq_array_like, # cf handling in normalize
81+
Optional[NDArray]: normalize_ndarray,
6882
DTypeLike: normalize_dtype,
6983
SubokLike: normalize_subok_like,
7084
AxisLike: normalize_axis_like,

torch_np/_unary_ufuncs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from . import _helpers
66
from ._detail import _unary_ufuncs
7-
from ._normalizations import ArrayLike, DTypeLike, SubokLike, normalizer
7+
from ._normalizations import ArrayLike, DTypeLike, SubokLike, NDArray, normalizer
8+
from typing import Optional
89

910
__all__ = [
1011
name for name in dir(_unary_ufuncs) if not name.startswith("_") and name != "torch"
@@ -21,7 +22,7 @@ def deco_unary_ufunc(torch_func):
2122
def wrapped(
2223
x: ArrayLike,
2324
/,
24-
out=None,
25+
out: Optional[NDArray] = None,
2526
*,
2627
where=True,
2728
casting="same_kind",

torch_np/_wrapper.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
from ._detail import implementations as _impl
1414
from ._ndarray import array, asarray, maybe_set_base, ndarray
1515

16-
### XXX: order the imports DAG
1716
from ._normalizations import (
1817
ArrayLike,
1918
DTypeLike,
2019
SubokLike,
2120
UnpackedSeqArrayLike,
21+
NDArray,
2222
normalizer,
2323
)
24+
from typing import Optional
2425

2526
from . import _dtypes, _helpers, _decorators # isort: skip # XXX
2627

@@ -108,9 +109,6 @@ def _concat_check(tup, dtype, out):
108109
raise ValueError("need at least one array to concatenate")
109110

110111
if out is not None:
111-
if not isinstance(out, ndarray):
112-
raise ValueError("'out' must be an array")
113-
114112
if dtype is not None:
115113
# mimic numpy
116114
raise TypeError(
@@ -123,7 +121,7 @@ def _concat_check(tup, dtype, out):
123121
def concatenate(
124122
ar_tuple: Sequence[ArrayLike],
125123
axis=0,
126-
out=None,
124+
out: Optional[NDArray]=None,
127125
dtype: DTypeLike = None,
128126
casting="same_kind",
129127
):
@@ -173,7 +171,7 @@ def column_stack(
173171
def stack(
174172
arrays: Sequence[ArrayLike],
175173
axis=0,
176-
out=None,
174+
out: Optional[NDArray] = None,
177175
*,
178176
dtype: DTypeLike = None,
179177
casting="same_kind",
@@ -666,7 +664,7 @@ def percentile(
666664
a,
667665
q,
668666
axis=None,
669-
out=None,
667+
out: Optional[NDArray] = None,
670668
overwrite_input=False,
671669
method="linear",
672670
keepdims=False,
@@ -678,7 +676,7 @@ def percentile(
678676
)
679677

680678

681-
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
679+
def median(a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False):
682680
return _funcs.quantile(
683681
a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
684682
)
@@ -691,7 +689,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):
691689

692690

693691
@normalizer
694-
def outer(a: ArrayLike, b: ArrayLike, out=None):
692+
def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
695693
result = torch.outer(a, b)
696694
return _helpers.result_or_out(result, out)
697695

@@ -704,7 +702,7 @@ def nanmean(
704702
a: ArrayLike,
705703
axis=None,
706704
dtype: DTypeLike = None,
707-
out=None,
705+
out: Optional[NDArray] = None,
708706
keepdims=NoValue,
709707
*,
710708
where=NoValue,
@@ -847,13 +845,13 @@ def isrealobj(x: ArrayLike):
847845

848846

849847
@normalizer
850-
def isneginf(x: ArrayLike, out=None):
848+
def isneginf(x: ArrayLike, out: Optional[NDArray] = None):
851849
result = torch.isneginf(x, out=out)
852850
return _helpers.array_from(result)
853851

854852

855853
@normalizer
856-
def isposinf(x: ArrayLike, out=None):
854+
def isposinf(x: ArrayLike, out: Optional[NDArray] = None):
857855
result = torch.isposinf(x, out=out)
858856
return _helpers.array_from(result)
859857

0 commit comments

Comments
 (0)