Commit c4dffa5

lint
1 parent 71ea537 commit c4dffa5

9 files changed: 227 additions, 85 deletions

torch_np/_binary_ufuncs.py

Lines changed: 15 additions & 10 deletions
@@ -1,12 +1,15 @@
 from ._detail import _binary_ufuncs

-__all__ = [name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"]
+__all__ = [
+    name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
+]


-# TODO: consolidate normalizations
-from ._funcs import normalizer, ArrayLike, SubokLike, DTypeLike
-from ._detail import _util
 from . import _helpers
+from ._detail import _util
+
+# TODO: consolidate normalizations
+from ._funcs import ArrayLike, DTypeLike, SubokLike, normalizer


 def deco_binary_ufunc(torch_func):
@@ -15,19 +18,20 @@ def deco_binary_ufunc(torch_func):
     Normalize arguments, sort out type casting, broadcasting and delegate to
     the pytorch functions for the actual work.
     """
+
     def wrapped(
-        x1 : ArrayLike,
-        x2 : ArrayLike,
+        x1: ArrayLike,
+        x2: ArrayLike,
         /,
         out=None,
         *,
         where=True,
         casting="same_kind",
         order="K",
-        dtype: DTypeLike=None,
-        subok: SubokLike=False,
+        dtype: DTypeLike = None,
+        subok: SubokLike = False,
         signature=None,
-        extobj=None
+        extobj=None,
     ):
         if order != "K" or not where or signature or extobj:
             raise NotImplementedError
@@ -47,6 +51,7 @@ def wrapped(

     return wrapped

+
 #
 # For each torch ufunc implementation, decorate and attach the decorated name
 # to this module. Its contents is then exported to the public namespace in __init__.py
@@ -55,6 +60,6 @@ def wrapped(
     ufunc = getattr(_binary_ufuncs, name)
     decorated = normalizer(deco_binary_ufunc(ufunc))

-    decorated.__qualname__ = name # XXX: is this really correct?
+    decorated.__qualname__ = name  # XXX: is this really correct?
     decorated.__name__ = name
     vars()[name] = decorated
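
The hunks above show the shape of this module: `deco_binary_ufunc` wraps each pytorch binary function in a numpy-compatible ufunc signature, and a module-level loop publishes the wrappers under their public names. A minimal, self-contained sketch of that decorate-and-attach pattern (the real wrapper also runs its arguments through `normalizer` and the casting helpers, which are elided here):

```python
import torch


def deco_binary_ufunc(torch_func):
    """Wrap a torch binary function in a numpy-style ufunc signature."""

    def wrapped(
        x1,
        x2,
        /,
        out=None,
        *,
        where=True,
        casting="same_kind",
        order="K",
        dtype=None,
        subok=False,
        signature=None,
        extobj=None,
    ):
        # Only the default ufunc machinery is supported, as in the diff.
        if order != "K" or not where or signature or extobj:
            raise NotImplementedError
        # Delegate the actual work to pytorch.
        return torch_func(torch.as_tensor(x1), torch.as_tensor(x2))

    return wrapped


# Decorate each torch function and attach it to the module namespace,
# mirroring the `vars()[name] = decorated` loop above.
for name in ["add", "multiply", "maximum"]:
    decorated = deco_binary_ufunc(getattr(torch, name))
    decorated.__qualname__ = name
    decorated.__name__ = name
    vars()[name] = decorated

print(vars()["add"]([1, 2], [3, 4]))  # tensor([4, 6])
```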

torch_np/_decorators.py

Lines changed: 0 additions & 4 deletions
@@ -1,5 +1,4 @@
 import functools
-import operator

 import torch

@@ -9,7 +8,6 @@
 NoValue = None


-
 def out_shape_dtype(func):
     """Handle out=... kwarg for ufuncs.

@@ -26,5 +24,3 @@ def wrapped(*args, out=None, **kwds):
         return _helpers.result_or_out(result_tensor, out)

     return wrapped
-
-
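
The surviving `out_shape_dtype` decorator implements numpy's `out=...` convention for ufuncs. Only its tail is visible in the hunks; below is a hedged sketch of its overall shape, with `result_or_out` reduced to a simplified stand-in for the `_helpers` version, whose body is not shown in this diff:

```python
import functools

import torch


def result_or_out(result_tensor, out):
    # Simplified stand-in: the real _helpers version also validates the
    # shape/dtype compatibility of `out` before writing into it.
    if out is None:
        return result_tensor
    out.copy_(result_tensor)
    return out


def out_shape_dtype(func):
    """Handle out=... kwarg for ufuncs."""

    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result_tensor = func(*args, **kwds)
        return result_or_out(result_tensor, out)

    return wrapped


@out_shape_dtype
def add(x, y):
    return torch.add(x, y)


buf = torch.empty(3)
add(torch.ones(3), torch.ones(3), out=buf)  # result is written into buf
```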

torch_np/_detail/_binary_ufuncs.py

Lines changed: 36 additions & 13 deletions
@@ -4,19 +4,43 @@

 import torch

-from . import _dtypes_impl, _util
-
-
-from torch import (add, arctan2, bitwise_and, bitwise_or, bitwise_xor, copysign, divide,
-    float_power, floor_divide, fmax, fmin, fmod, gcd, greater, greater_equal, heaviside,
-    hypot, lcm, ldexp, less, less_equal, logaddexp, logaddexp2, logical_and,
-    logical_or, logical_xor, maximum, minimum, remainder, multiply, nextafter, not_equal,
-    remainder, subtract, divide)
-
-
 # renames
-from torch import (eq as equal, pow as power, bitwise_right_shift as right_shift,
-    bitwise_left_shift as left_shift,)
+from torch import add, arctan2, bitwise_and
+from torch import bitwise_left_shift as left_shift
+from torch import bitwise_or
+from torch import bitwise_right_shift as right_shift
+from torch import bitwise_xor, copysign, divide
+from torch import eq as equal
+from torch import (
+    float_power,
+    floor_divide,
+    fmax,
+    fmin,
+    fmod,
+    gcd,
+    greater,
+    greater_equal,
+    heaviside,
+    hypot,
+    lcm,
+    ldexp,
+    less,
+    less_equal,
+    logaddexp,
+    logaddexp2,
+    logical_and,
+    logical_or,
+    logical_xor,
+    maximum,
+    minimum,
+    multiply,
+    nextafter,
+    not_equal,
+)
+from torch import pow as power
+from torch import remainder, subtract
+
+from . import _dtypes_impl, _util


 # work around torch limitations w.r.t. numpy
@@ -27,4 +51,3 @@ def matmul(x, y):
     y = _util.cast_if_needed(y, dtype)
     result = torch.matmul(x, y)
     return result
-
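
The `matmul` shim whose tail closes this file exists because `torch.matmul` refuses mixed-dtype operands that numpy would happily promote. Only the last three lines appear in the hunk; the sketch below is a plausible reconstruction, with the common-dtype step via `torch.result_type` standing in for whatever `_dtypes_impl` actually provides:

```python
import torch


def cast_if_needed(tensor, dtype):
    # Mirrors the visible _util.cast_if_needed call: cast only on mismatch.
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor


def matmul(x, y):
    # numpy promotes int @ float to float; torch.matmul raises instead,
    # so compute a common dtype and cast both operands to it first.
    dtype = torch.result_type(x, y)
    x = cast_if_needed(x, dtype)
    y = cast_if_needed(y, dtype)
    result = torch.matmul(x, y)
    return result


print(matmul(torch.arange(4).reshape(2, 2), torch.eye(2)))  # float32 result
```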

torch_np/_detail/_reductions.py

Lines changed: 4 additions & 3 deletions
@@ -13,13 +13,13 @@

 import functools

-
 ############# XXX
 ### From _util.axis_expand_func


 def deco_axis_expand(func):
     """Generically handle axis arguments in reductions."""
+
     @functools.wraps(func)
     def wrapped(tensor, axis, *args, **kwds):

@@ -46,11 +46,13 @@ def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds):
         if keepdims:
             result = _util.apply_keepdims(result, axis, tensor.ndim)
         return result
+
     return wrapped


 def deco_axis_ravel(func):
     """Generically handle 'axis=None ravels' behavior."""
+
     @functools.wraps(func)
     def wrapped(tensor, axis, *args, **kwds):
         if axis is not None:
@@ -61,6 +63,7 @@ def wrapped(tensor, axis, *args, **kwds):

         result = func(tensor, axis=axis, *args, **kwds)
         return result
+
     return wrapped


@@ -292,7 +295,6 @@ def cumsum(tensor, axis, dtype=None):
     return result


-
 def average(a, axis, weights, returned=False, keepdims=False):
     if weights is None:
         result, wsum = average_noweights(a, axis, keepdims=keepdims)
@@ -384,7 +386,6 @@ def quantile(a_tensor, q_tensor, axis, method, keepdims=False):

     q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)

-
     # axis=None ravels, so store the originals to reuse with keepdims=True below
     ax, ndim = axis, a_tensor.ndim
     (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
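
Both decorators touched here implement numpy's generic axis semantics on top of torch: `deco_axis_expand` normalizes axis arguments and reinserts dimensions for `keepdims`, while `deco_axis_ravel` implements "axis=None operates on the flattened array". A minimal sketch of the latter, with `axis_none_ravel` as a simplified stand-in for the `_util` helper:

```python
import functools

import torch


def axis_none_ravel(tensor, axis):
    # Simplified stand-in: numpy's axis=None means "use the raveled array".
    if axis is None:
        return tensor.ravel(), 0
    return tensor, axis


def deco_axis_ravel(func):
    """Generically handle 'axis=None ravels' behavior."""

    @functools.wraps(func)
    def wrapped(tensor, axis, *args, **kwds):
        tensor, axis = axis_none_ravel(tensor, axis)
        return func(tensor, *args, axis=axis, **kwds)

    return wrapped


@deco_axis_ravel
def cumsum(tensor, axis, dtype=None):
    return torch.cumsum(tensor, dim=axis, dtype=dtype)


t = torch.arange(6).reshape(2, 3)
print(cumsum(t, None))  # tensor([ 0,  1,  3,  6, 10, 15]), as numpy would
```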

torch_np/_detail/_unary_ufuncs.py

Lines changed: 29 additions & 9 deletions
@@ -4,15 +4,36 @@

 import torch

-from torch import (arccos, arccosh, arcsin, arcsinh, arctan, arctanh, ceil,
-    cos, cosh, deg2rad, exp, exp2, expm1,
-    floor, isfinite, isinf, isnan, log, log10, log1p, log2, logical_not,
-    negative, rad2deg, reciprocal, sign, signbit,
-    sin, sinh, sqrt, square, tan, tanh, trunc)
-
 # renames
-from torch import (conj_physical as conjugate, round as rint, bitwise_not as invert, rad2deg as degrees,
-    deg2rad as radians, absolute as fabs, )
+from torch import absolute as fabs
+from torch import arccos, arccosh, arcsin, arcsinh, arctan, arctanh
+from torch import bitwise_not as invert
+from torch import ceil
+from torch import conj_physical as conjugate
+from torch import cos, cosh
+from torch import deg2rad
+from torch import deg2rad as radians
+from torch import (
+    exp,
+    exp2,
+    expm1,
+    floor,
+    isfinite,
+    isinf,
+    isnan,
+    log,
+    log1p,
+    log2,
+    log10,
+    logical_not,
+    negative,
+)
+from torch import rad2deg
+from torch import rad2deg as degrees
+from torch import reciprocal
+from torch import round as rint
+from torch import sign, signbit, sin, sinh, sqrt, square, tan, tanh, trunc
+

 # special cases: torch does not export these names
 def cbrt(x):
@@ -32,4 +53,3 @@ def absolute(x):

 abs = absolute
 conj = conjugate
-
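
`cbrt` and `absolute` are defined by hand because, as the comment in the hunk says, torch does not export these names. Their bodies are not part of this diff, so the sketch below is one plausible way to fill them in, not necessarily the repo's implementation:

```python
import torch
from torch import conj_physical as conjugate


def cbrt(x):
    # Real-valued cube root: unlike x ** (1 / 3), stays real for x < 0.
    return torch.sign(x) * torch.abs(x).pow(1.0 / 3.0)


def absolute(x):
    return torch.absolute(x)


# Aliases matching the tail of the diff.
abs = absolute
conj = conjugate

print(cbrt(torch.tensor([-8.0, 27.0])))  # tensor([-2., 3.])
```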
