
Commit be59ba3

MAINT: remove isort:skip directives (circ imports are well hidden now)
1 parent: fe9011d

File tree

8 files changed: +83 −42 lines

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,3 @@
-from ._wrapper import *  # isort: skip  # XXX: currently this prevents circular imports
 from . import random
 from ._binary_ufuncs import *
 from ._detail._index_tricks import *
@@ -8,6 +7,7 @@
 from ._getlimits import errstate, finfo, iinfo
 from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
 from ._unary_ufuncs import *
+from ._wrapper import *

 # from . import testing

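Note on the change above: an "# isort: skip" comment exempts a single import from isort's reordering. The skip existed because `from ._wrapper import *` had to run first to break a circular-import cycle; with the cycle untangled, the import can move to its natural alphabetical slot. A minimal sketch of the failure mode such a pin guards against (module names here are hypothetical, not from torch_np):

# a.py -- one half of a hypothetical import cycle
import b  # executing a.py starts executing b.py ...

def helper():
    return "defined late in a.py"

# b.py -- the other half
import a  # ... which re-enters the half-initialized a.py

value = a.helper()  # AttributeError at import time: `helper` is not yet bound

Pinning one import to run first in the package __init__ (the role the removed skip played) is a common workaround; dropping the pin is only safe once no module depends on import order, which is what this commit asserts.
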
torch_np/_binary_ufuncs.py

Lines changed: 3 additions & 3 deletions

@@ -1,8 +1,8 @@
-from . import _helpers
-from ._detail import _binary_ufuncs
-from ._normalizations import ArrayLike, DTypeLike, SubokLike, NDArray, normalizer
 from typing import Optional

+from . import _helpers
+from ._detail import _binary_ufuncs
+from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer

 __all__ = [
     name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"

torch_np/_detail/_reductions.py

Lines changed: 5 additions & 1 deletion

@@ -6,6 +6,8 @@

 import torch

+import typing
+
 from . import _dtypes_impl, _util

 NoValue = None
@@ -26,7 +28,9 @@ def wrapped(tensor, axis, *args, **kwds):
         if axis is not None:
             if not isinstance(axis, (list, tuple)):
                 if not isinstance(axis, typing.SupportsIndex):
-                    raise TypeError(f"{type(axis)=}, but should be a list/tuple or support operator.index()")
+                    raise TypeError(
+                        f"{type(axis)=}, but should be a list/tuple or support operator.index()"
+                    )
                 axis = (axis,)
         axis = _util.normalize_axis_tuple(axis, tensor.ndim)

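The axis guard above leans on `typing.SupportsIndex`, a runtime-checkable protocol (Python 3.8+): `isinstance(x, typing.SupportsIndex)` is true for any object that defines `__index__`, i.e. exactly what `operator.index()` accepts. A quick standalone illustration:

import typing

class MyAxis:
    # Any object with __index__ satisfies the protocol.
    def __index__(self) -> int:
        return 1

print(isinstance(3, typing.SupportsIndex))         # True: ints define __index__
print(isinstance(MyAxis(), typing.SupportsIndex))  # True
print(isinstance(3.0, typing.SupportsIndex))       # False: floats do not
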
torch_np/_funcs.py

Lines changed: 55 additions & 15 deletions

@@ -14,7 +14,6 @@
     UnpackedSeqArrayLike,
     normalizer,
 )
-from typing import Optional


 @normalizer
@@ -59,7 +58,14 @@ def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):


 @normalizer
-def trace(a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: DTypeLike = None, out: Optional[NDArray] = None):
+def trace(
+    a: ArrayLike,
+    offset=0,
+    axis1=0,
+    axis2=1,
+    dtype: DTypeLike = None,
+    out: Optional[NDArray] = None,
+):
     result = _impl.trace(a, offset, axis1, axis2, dtype)
     return _helpers.result_or_out(result, out)

@@ -213,7 +219,7 @@ def imag(a: ArrayLike):


 @normalizer
-def round_(a: ArrayLike, decimals=0, out: Optional[NDArray]=None):
+def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
     result = _impl.round(a, decimals)
     return _helpers.result_or_out(result, out)

@@ -233,7 +239,7 @@ def sum(
     a: ArrayLike,
     axis: AxisLike = None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     keepdims=NoValue,
     initial=NoValue,
     where=NoValue,
@@ -249,7 +255,7 @@ def prod(
     a: ArrayLike,
     axis: AxisLike = None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     keepdims=NoValue,
     initial=NoValue,
     where=NoValue,
@@ -268,7 +274,7 @@ def mean(
     a: ArrayLike,
     axis: AxisLike = None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     keepdims=NoValue,
     *,
     where=NoValue,
@@ -284,7 +290,7 @@ def var(
     a: ArrayLike,
     axis: AxisLike = None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     ddof=0,
     keepdims=NoValue,
     *,
@@ -301,7 +307,7 @@ def std(
     a: ArrayLike,
     axis: AxisLike = None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     ddof=0,
     keepdims=NoValue,
     *,
@@ -314,13 +320,25 @@ def std(


 @normalizer
-def argmin(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, *, keepdims=NoValue):
+def argmin(
+    a: ArrayLike,
+    axis: AxisLike = None,
+    out: Optional[NDArray] = None,
+    *,
+    keepdims=NoValue,
+):
     result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
     return _helpers.result_or_out(result, out)


 @normalizer
-def argmax(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, *, keepdims=NoValue):
+def argmax(
+    a: ArrayLike,
+    axis: AxisLike = None,
+    out: Optional[NDArray] = None,
+    *,
+    keepdims=NoValue,
+):
     result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
     return _helpers.result_or_out(result, out)

@@ -362,22 +380,34 @@ def amin(


 @normalizer
-def ptp(a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue):
+def ptp(
+    a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
+):
     result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
     return _helpers.result_or_out(result, out)


 @normalizer
 def all(
-    a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue, *, where=NoValue
+    a: ArrayLike,
+    axis: AxisLike = None,
+    out: Optional[NDArray] = None,
+    keepdims=NoValue,
+    *,
+    where=NoValue,
 ):
     result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
     return _helpers.result_or_out(result, out)


 @normalizer
 def any(
-    a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue, *, where=NoValue
+    a: ArrayLike,
+    axis: AxisLike = None,
+    out: Optional[NDArray] = None,
+    keepdims=NoValue,
+    *,
+    where=NoValue,
 ):
     result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
     return _helpers.result_or_out(result, out)
@@ -390,13 +420,23 @@ def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):


 @normalizer
-def cumsum(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out: Optional[NDArray] = None):
+def cumsum(
+    a: ArrayLike,
+    axis: AxisLike = None,
+    dtype: DTypeLike = None,
+    out: Optional[NDArray] = None,
+):
     result = _reductions.cumsum(a, axis=axis, dtype=dtype)
     return _helpers.result_or_out(result, out)


 @normalizer
-def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out: Optional[NDArray] = None):
+def cumprod(
+    a: ArrayLike,
+    axis: AxisLike = None,
+    dtype: DTypeLike = None,
+    out: Optional[NDArray] = None,
+):
     result = _reductions.cumprod(a, axis=axis, dtype=dtype)
     return _helpers.result_or_out(result, out)

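Nearly every wrapper in this file funnels through `_helpers.result_or_out(result, out)`, which implements NumPy's `out=` convention. The helper itself is not shown in this diff; the following is only a simplified sketch of the semantics, assuming a torch_np-style ndarray that exposes its backing tensor through a hypothetical `.tensor` attribute and an `asarray` constructor:

# Sketch only: NumPy-style out= handling. The actual torch_np helper
# also handles dtype casting rules and tuple-valued results.
def result_or_out(result_tensor, out=None):
    if out is None:
        # No out= array supplied: wrap the freshly computed tensor.
        return asarray(result_tensor)
    if tuple(out.shape) != tuple(result_tensor.shape):
        raise ValueError("non-broadcastable output operand")
    out.tensor.copy_(result_tensor)  # write the result in place ...
    return out                       # ... and hand back the caller's array

This is why only the signatures needed touching here (the `out: Optional[NDArray] = None` annotations normalized and long defs split by the formatter): the actual out-handling lives in one place.
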
torch_np/_normalizations.py

Lines changed: 0 additions & 1 deletion

@@ -72,7 +72,6 @@ def normalize_ndarray(arg, name=None):
     return arg


-
 normalizers = {
     ArrayLike: normalize_array_like,
     Optional[ArrayLike]: normalize_optional_array_like,

torch_np/_unary_ufuncs.py

Lines changed: 3 additions & 2 deletions

@@ -2,10 +2,11 @@
 # from ._detail import _ufunc_impl


+from typing import Optional
+
 from . import _helpers
 from ._detail import _unary_ufuncs
-from ._normalizations import ArrayLike, DTypeLike, SubokLike, NDArray, normalizer
-from typing import Optional
+from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer

 __all__ = [
     name for name in dir(_unary_ufuncs) if not name.startswith("_") and name != "torch"

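Both ufunc modules build their public namespace programmatically: `__all__` collects every public name from the `_detail` implementation module, skipping underscore-prefixed names and the `torch` module that leaks in via that module's own imports. The same pattern in isolation, using `math` as a stand-in implementation module:

import math as _impl  # stand-in for torch_np._detail._unary_ufuncs

# Re-export every public name except unwanted module-level imports.
__all__ = [
    name for name in dir(_impl) if not name.startswith("_") and name != "torch"
]

print(__all__[:5])  # ['acos', 'acosh', 'asin', 'asinh', 'atan']
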
torch_np/_wrapper.py

Lines changed: 12 additions & 15 deletions

@@ -8,23 +8,18 @@

 import torch

-from . import _funcs
+from . import _decorators, _dtypes, _funcs, _helpers
 from ._detail import _dtypes_impl, _flips, _reductions, _util
 from ._detail import implementations as _impl
 from ._ndarray import array, asarray, maybe_set_base, ndarray
-
 from ._normalizations import (
     ArrayLike,
     DTypeLike,
+    NDArray,
     SubokLike,
     UnpackedSeqArrayLike,
-    NDArray,
     normalizer,
 )
-from typing import Optional
-
-from . import _dtypes, _helpers, _decorators  # isort: skip  # XXX
-

 # Things to decide on (punt for now)
 #
@@ -121,7 +116,7 @@ def _concat_check(tup, dtype, out):
 def concatenate(
     ar_tuple: Sequence[ArrayLike],
     axis=0,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     dtype: DTypeLike = None,
     casting="same_kind",
 ):
@@ -171,7 +166,7 @@ def column_stack(
 def stack(
     arrays: Sequence[ArrayLike],
     axis=0,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     *,
     dtype: DTypeLike = None,
     casting="same_kind",
@@ -664,7 +659,7 @@ def percentile(
     a,
     q,
     axis=None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     overwrite_input=False,
     method="linear",
     keepdims=False,
@@ -676,7 +671,9 @@
     )


-def median(a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False):
+def median(
+    a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False
+):
     return _funcs.quantile(
         a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
     )
@@ -689,7 +686,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):


 @normalizer
-def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray]=None):
+def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
     result = torch.outer(a, b)
     return _helpers.result_or_out(result, out)

@@ -702,7 +699,7 @@ def nanmean(
     a: ArrayLike,
     axis=None,
     dtype: DTypeLike = None,
-    out: Optional[NDArray]=None,
+    out: Optional[NDArray] = None,
     keepdims=NoValue,
     *,
     where=NoValue,
@@ -845,13 +842,13 @@ def isrealobj(x: ArrayLike):


 @normalizer
-def isneginf(x: ArrayLike, out: Optional[NDArray]=None):
+def isneginf(x: ArrayLike, out: Optional[NDArray] = None):
     result = torch.isneginf(x, out=out)
     return _helpers.array_from(result)


 @normalizer
-def isposinf(x: ArrayLike, out: Optional[NDArray]=None):
+def isposinf(x: ArrayLike, out: Optional[NDArray] = None):
     result = torch.isposinf(x, out=out)
     return _helpers.array_from(result)

torch_np/random.py

Lines changed: 4 additions & 4 deletions

@@ -6,13 +6,13 @@

 """
 from math import sqrt
+from typing import Optional

 import torch

-from ._detail import _dtypes_impl, _util
 from . import _helpers
-from ._normalizations import normalizer, ArrayLike
-from typing import Optional
+from ._detail import _dtypes_impl, _util
+from ._normalizations import ArrayLike, normalizer

 _default_dtype = _dtypes_impl.default_float_dtype

@@ -96,7 +96,7 @@ def randint(low, high=None, size=None):


 @normalizer
-def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike]=None):
+def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):

     # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
     if a.numel() == 1:

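The `choice` wrapper (per the StackOverflow link kept in the diff) exists because PyTorch has no direct `numpy.random.choice` equivalent. A common recipe, sketched here as a plain torch function rather than torch_np's actual implementation: draw indices with `torch.multinomial` according to the given probabilities, then index into the flattened source tensor.

from typing import Optional

import torch

def choice(a: torch.Tensor, size: int, replace: bool = True,
           p: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Uniform probabilities unless the caller supplies weights.
    if p is None:
        p = torch.ones(a.numel()) / a.numel()
    # multinomial draws `size` indices distributed according to `p`;
    # replacement=False requires size <= a.numel().
    indices = torch.multinomial(p, size, replacement=replace)
    return a.flatten()[indices]

draws = choice(torch.arange(10), size=3, replace=False)  # e.g. tensor([7, 2, 9])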