Skip to content

Commit 679e053

Browse files
committed
lint
1 parent 3da48a2 commit 679e053

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

torch_np/_detail/_flips.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from . import _scalar_types, _util
77

8+
89
def flip(m_tensor, axis=None):
910
# XXX: semantic difference: np.flip returns a view, torch.flip copies
1011
if axis is None:
@@ -39,7 +40,7 @@ def rollaxis(tensor, axis, start=0):
3940
start += n
4041
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
4142
if not (0 <= start < n + 1):
42-
raise _util.AxisError(msg % ('start', -n, 'start', n + 1, start))
43+
raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
4344
if axis < start:
4445
# it's been removed
4546
start -= 1

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
dtype_to_torch,
1111
emulate_out_arg,
1212
)
13-
from ._detail import _reductions, _util, _flips
13+
from ._detail import _flips, _reductions, _util
1414

1515
newaxis = None
1616

torch_np/_wrapper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99

1010
from . import _dtypes, _helpers
11-
from ._detail import _reductions, _util, _flips
11+
from ._detail import _flips, _reductions, _util
1212
from ._ndarray import (
1313
array,
1414
asarray,
@@ -476,15 +476,16 @@ def broadcast_arrays(*args, subok=False):
476476

477477
@asarray_replacer()
478478
def moveaxis(a, source, destination):
479-
source = _util.normalize_axis_tuple(source, a.ndim, 'source')
480-
destination = _util.normalize_axis_tuple(destination, a.ndim, 'destination')
479+
source = _util.normalize_axis_tuple(source, a.ndim, "source")
480+
destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
481481
return asarray(torch.moveaxis(a, source, destination))
482482

483483

484484
def swapaxis(a, axis1, axis2):
485485
arr = asarray(a)
486486
return arr.swapaxes(axis1, axis2)
487487

488+
488489
@asarray_replacer()
489490
def rollaxis(a, axis, start=0):
490491
return _flips.rollaxis(a, axis, start)
@@ -665,8 +666,10 @@ def prod(
665666
axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where
666667
)
667668

669+
668670
product = prod
669671

672+
670673
def cumprod(a, axis=None, dtype=None, out=None):
671674
arr = asarray(a)
672675
return arr.cumprod(axis=axis, dtype=dtype, out=out)

0 commit comments

Comments
 (0)