
Commit f26148f

lint
1 parent 68f6c6f commit f26148f

File tree: 7 files changed (+52 -48 lines)


autogen/numpy_api_dump.py

Lines changed: 0 additions & 1 deletion
@@ -774,7 +774,6 @@ def product(*args, **kwargs):
     raise NotImplementedError


-
 def put(a, ind, v, mode="raise"):
     raise NotImplementedError


torch_np/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -13,6 +13,5 @@
 alltrue = all
 sometrue = any

-inf = float('inf')
-nan = float('nan')
-
+inf = float("inf")
+nan = float("nan")

torch_np/_decorators.py

Lines changed: 3 additions & 1 deletion
@@ -122,9 +122,11 @@ def axis_none_ravel_wrapper(func):
     similar logic.

     """
+
     @functools.wraps(func)
     def wrapped(a, axis=None, *args, **kwds):
-        from ._ndarray import ndarray, asarray
+        from ._ndarray import asarray, ndarray
+
         tensor = asarray(a).get()

         # standardize the axis argument
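
For orientation: this decorator's job is to ravel the input when axis=None and apply the wrapped reduction along axis 0 of the flattened tensor. A rough, self-contained sketch of that pattern (ravel_if_axis_none and cumsum_sketch are illustrative names, not torch_np's actual internals):

import functools

import torch


def ravel_if_axis_none(func):
    # Hypothetical stand-in for axis_none_ravel_wrapper: when axis is None,
    # flatten the input and apply the wrapped function along axis 0.
    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        tensor = torch.as_tensor(a)
        if axis is None:
            tensor, axis = tensor.ravel(), 0
        return func(tensor, axis, *args, **kwds)

    return wrapped


@ravel_if_axis_none
def cumsum_sketch(tensor, axis):
    return tensor.cumsum(axis)


print(cumsum_sketch([[1, 2], [3, 4]]))  # tensor([ 1,  3,  6, 10])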

torch_np/_detail/_reductions.py

Lines changed: 13 additions & 11 deletions
@@ -172,11 +172,12 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
 # 1. no keepdims
 # 2. axis=None ravels (cf concatenate)

+
 def cumprod(tensor, axis=None, dtype=None):
     if dtype == torch.bool:
         dtype = _scalar_types.default_int_type.dtype
     if dtype is None:
-        dtype=tensor.dtype
+        dtype = tensor.dtype

     result = tensor.cumprod(axis=axis, dtype=dtype)

@@ -187,7 +188,7 @@ def cumsum(tensor, axis=None, dtype=None):
     if dtype == torch.bool:
         dtype = _scalar_types.default_int_type.dtype
     if dtype is None:
-        dtype=tensor.dtype
+        dtype = tensor.dtype

     result = tensor.cumsum(axis=axis, dtype=dtype)

@@ -205,23 +206,25 @@ def average(a_tensor, axis, w_tensor):

     # axis
     if axis is None:
-        (a_tensor, w_tensor), axis = _util.axis_none_ravel(a_tensor, w_tensor, axis=axis)
+        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
+            a_tensor, w_tensor, axis=axis
+        )

     # axis & weights
     if a_tensor.shape != w_tensor.shape:
         if axis is None:
             raise TypeError(
-                "Axis must be specified when shapes of a and weights "
-                "differ.")
+                "Axis must be specified when shapes of a and weights " "differ."
+            )
         if w_tensor.ndim != 1:
-            raise TypeError(
-                "1D weights expected when shapes of a and weights differ.")
+            raise TypeError("1D weights expected when shapes of a and weights differ.")
         if w_tensor.shape[0] != a_tensor.shape[axis]:
-            raise ValueError(
-                "Length of weights not compatible with specified axis.")
+            raise ValueError("Length of weights not compatible with specified axis.")

     # setup weight to broadcast along axis
-    w_tensor = torch.broadcast_to(w_tensor, (a_tensor.ndim-1)*(1,) + w_tensor.shape)
+    w_tensor = torch.broadcast_to(
+        w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
+    )
     w_tensor = w_tensor.swapaxes(-1, axis)

     # do the work

@@ -232,7 +235,6 @@ def average(a_tensor, axis, w_tensor):
     return result, denominator


-
 def quantile(a_tensor, q_tensor, axis, method):

     if (0 > q_tensor).any() or (q_tensor > 1).any():
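
Beyond the quoting and line-length fixes, the re-wrapped broadcast_to call is the interesting bit: average pads a 1-D weights vector with leading singleton dimensions, then swaps its data axis into position so it broadcasts against the input along the reduction axis. A standalone sketch of that trick in plain torch (variable names illustrative, independent of torch_np's helpers):

import torch

a = torch.arange(12.0).reshape(3, 4)
w = torch.tensor([1.0, 2.0, 3.0])  # weights along axis 0
axis = 0

# pad w with leading singleton dims, then move its data axis into position
w_b = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)  # shape (1, 3)
w_b = w_b.swapaxes(-1, axis)                                # shape (3, 1)

# weighted average along `axis`, mirroring the numerator/denominator split above
print((a * w_b).sum(axis) / w_b.sum(axis))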

torch_np/_ndarray.py

Lines changed: 15 additions & 13 deletions
@@ -2,17 +2,15 @@

 import torch

-from ._detail import _util
-from ._detail import _reductions
-from . import _helpers
-from . import _dtypes
-from . import _unary_ufuncs
-from . import _binary_ufuncs
-
-from ._decorators import (emulate_out_arg, axis_keepdims_wrapper,
-                          dtype_to_torch, axis_none_ravel_wrapper)
-
-from ._decorators import NoValue
+from . import _binary_ufuncs, _dtypes, _helpers, _unary_ufuncs
+from ._decorators import (
+    NoValue,
+    axis_keepdims_wrapper,
+    axis_none_ravel_wrapper,
+    dtype_to_torch,
+    emulate_out_arg,
+)
+from ._detail import _reductions, _util

 newaxis = None

@@ -290,8 +288,12 @@ def nonzero(self):
     var = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.var)))
     std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std)))

-    cumprod = emulate_out_arg(axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumprod)))
-    cumsum = emulate_out_arg(axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum)))
+    cumprod = emulate_out_arg(
+        axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumprod))
+    )
+    cumsum = emulate_out_arg(
+        axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum))
+    )

     ### indexing ###
     def __getitem__(self, *args, **kwds):
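
The cumprod/cumsum attributes that black split across lines are stacked decorators: a tensor-level reduction from _reductions wrapped with dtype conversion, axis=None ravelling, and out= emulation. A minimal sketch of the outermost layer only, where emulate_out_sketch is a hypothetical stand-in for emulate_out_arg rather than the actual implementation:

import functools

import torch


def emulate_out_sketch(func):
    # Hypothetical stand-in for emulate_out_arg: compute normally, then copy
    # the result into a caller-supplied `out` tensor when one is given.
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)
        if out is not None:
            out.copy_(result)
            return out
        return result

    return wrapped


@emulate_out_sketch
def cumsum_sketch(tensor, axis=0):
    return tensor.cumsum(axis)


buf = torch.empty(4, dtype=torch.int64)
cumsum_sketch(torch.tensor([1, 2, 3, 4]), out=buf)
print(buf)  # tensor([ 1,  3,  6, 10])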

torch_np/_wrapper.py

Lines changed: 11 additions & 8 deletions
@@ -645,7 +645,6 @@ def prod(
     )


-
 def cumprod(a, axis=None, dtype=None, out=None):
     arr = asarray(a)
     return arr.cumprod(axis=axis, dtype=dtype, out=out)

@@ -658,7 +657,9 @@ def cumsum(a, axis=None, dtype=None, out=None):
     arr = asarray(a)
     return arr.cumsum(axis=axis, dtype=dtype, out=out)

-#YYY: pattern : ddof
+
+# YYY: pattern : ddof
+

 def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
     arr = asarray(a)

@@ -679,7 +680,7 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
     if weights is None:
         result = mean(a, axis=axis, keepdims=keepdims)
         if returned:
-            scl = result.dtype.type(a.size/result.size)
+            scl = result.dtype.type(a.size / result.size)
             return result, scl
         return result

@@ -713,8 +714,9 @@ def percentile(
     *,
     interpolation=None,
 ):
-    return quantile(a, asarray(q) / 100., axis, out, overwrite_input, method,
-                    keepdims=keepdims)
+    return quantile(
+        a, asarray(q) / 100.0, axis, out, overwrite_input, method, keepdims=keepdims
+    )


 def quantile(

@@ -736,13 +738,14 @@ def quantile(

     # keepdims
     if keepdims:
-        result = _util.apply_keepdims(result, axis, a_tensor.ndim)
+        result = _util.apply_keepdims(result, axis, a_tensor.ndim)
     return _helpers.result_or_out(result, out, promote_scalar=True)


 def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
-    return quantile(a, 0.5, axis=axis, overwrite_input=overwrite_input,
-                    out=out, keepdims=keepdims)
+    return quantile(
+        a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
+    )


 @asarray_replacer()
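
The re-wrapped calls preserve two identities: percentile(a, q) delegates to quantile(a, q / 100), and median(a) is quantile(a, 0.5). A quick check of both against plain torch.quantile (not torch_np's wrappers):

import torch

a = torch.arange(10.0)

# percentile(q) == quantile(q / 100)
assert torch.equal(torch.quantile(a, 75 / 100.0), torch.quantile(a, 0.75))

# median == quantile(0.5); note that torch.median returns the lower of the two
# middle elements, so the quantile form (linear interpolation) is the
# NumPy-compatible one
print(torch.quantile(a, 0.5))  # tensor(4.5000)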

torch_np/tests/test_reductions.py

Lines changed: 8 additions & 11 deletions
@@ -608,6 +608,7 @@ def setup_method(self):
         self.func = np.mean
         self.allowed_axes = [0, 1, 2, -1, -2, (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]

+
 class TestSumGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin):
     def setup_method(self):
         self.func = np.sum

@@ -635,8 +636,7 @@ def setup_method(self):
 class TestVarGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin):
     def setup_method(self):
         self.func = np.var
-        self.allowed_axes = [0, 1, 2, -1, -2,
-                             (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]
+        self.allowed_axes = [0, 1, 2, -1, -2, (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]


 class _GenericCumSumProdTestMixin:

@@ -648,36 +648,34 @@ class _GenericCumSumProdTestMixin:

     To use: subclass, define self.func and self.allowed_axes.
     """
+
     def test_bad_axis(self):
         # Basic check of functionality
         m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

-        assert_raises(TypeError, self.func, m, axis='foo')
+        assert_raises(TypeError, self.func, m, axis="foo")
         assert_raises(np.AxisError, self.func, m, axis=3)
-        assert_raises(TypeError, self.func,
-                      m, axis=np.array([[1], [2]]))
+        assert_raises(TypeError, self.func, m, axis=np.array([[1], [2]]))
         assert_raises(TypeError, self.func, m, axis=1.5)

         # TODO: add tests with np.int32(3) etc, when implemented

     def test_array_axis(self):
         a = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
-        assert_equal(self.func(a, axis=np.array(-1)),
-                     self.func(a, axis=-1))
+        assert_equal(self.func(a, axis=np.array(-1)), self.func(a, axis=-1))

         with assert_raises(TypeError):
             self.func(a, axis=np.array([1, 2]))

     def test_axis_empty_generic(self):
         a = np.array([[0, 0, 1], [1, 0, 1]])
-        assert_array_equal(self.func(a, axis=None),
-                           self.func(a.ravel(), axis=0))
+        assert_array_equal(self.func(a, axis=None), self.func(a.ravel(), axis=0))

     def test_axis_bad_tuple(self):
         # Basic check of functionality
         m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
         with assert_raises(TypeError):
-            self.func(m, axis=(1, 1))
+            self.func(m, axis=(1, 1))


 class TestCumProdGeneric(_GenericCumSumProdTestMixin):

@@ -688,4 +686,3 @@ def setup_method(self):
 class TestCumSumGeneric(_GenericCumSumProdTestMixin):
     def setup_method(self):
         self.func = np.cumsum
-
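
These tests follow a mixin pattern: the base class carries the generic axis checks and each concrete Test* subclass only binds self.func in setup_method. A trimmed-down sketch of that structure against plain NumPy (class names here are illustrative):

import numpy as np
from numpy.testing import assert_array_equal


class _CumFuncMixinSketch:
    # Generic check shared by all subclasses; self.func is set in setup_method.
    def test_axis_none_matches_ravel(self):
        a = np.array([[0, 0, 1], [1, 0, 1]])
        assert_array_equal(self.func(a, axis=None), self.func(a.ravel(), axis=0))


class TestCumSumSketch(_CumFuncMixinSketch):
    def setup_method(self):
        self.func = np.cumsum


class TestCumProdSketch(_CumFuncMixinSketch):
    def setup_method(self):
        self.func = np.cumprod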
