Skip to content

Commit b7f035f

Browse files
committed
prostrate myself at the face of black
1 parent 22130af commit b7f035f

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

torch_np/_dtypes_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def nep50_to_tensors(x1, x2, handle_weaks):
123123
if cat_weak == 1 and cat_not_weak == 1:
124124
# integers
125125
iinfo = torch.iinfo(not_weak.dtype)
126-
if weak < iinfo.min or weak > iinfo.max:
126+
if weak < iinfo.min or weak > iinfo.max:
127127
raise OverflowError(
128128
f"Python integer {weak} out of bounds for {not_weak.dtype}"
129129
)

torch_np/_ufuncs.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,18 @@ def _ufunc_postprocess(result, out, casting):
3232
]
3333

3434

35-
NEP50_FUNCS = ("add", "subtract", "multiply", "floor_divide", "remainder", "bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift")
36-
35+
NEP50_FUNCS = (
36+
"add",
37+
"subtract",
38+
"multiply",
39+
"floor_divide",
40+
"remainder",
41+
"bitwise_and",
42+
"bitwise_or",
43+
"bitwise_xor",
44+
"left_shift",
45+
"right_shift",
46+
)
3747

3848

3949
def deco_binary_ufunc(torch_func):
@@ -58,8 +68,10 @@ def wrapped(
5868
signature=None,
5969
extobj=None,
6070
):
61-
flag = (torch_func.__name__ in NEP50_FUNCS and
62-
_dtypes_impl.default_dtypes == _dtypes_impl.default_dtypes_numpy)
71+
flag = (
72+
torch_func.__name__ in NEP50_FUNCS
73+
and _dtypes_impl.default_dtypes == _dtypes_impl.default_dtypes_numpy
74+
)
6375
x1, x2 = _dtypes_impl.nep50_to_tensors(x1, x2, flag)
6476

6577
if dtype is None:

torch_np/tests/test_nep50_examples.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
try:
66
import numpy as _np
7+
78
HAVE_NUMPY = True
89
except ImportError:
910
HAVE_NUMPY = False
@@ -99,14 +100,21 @@ def test_nep50_exceptions(example):
99100
assert result.dtype == new.dtype
100101

101102

102-
103103
# ### Directly compare to numpy ###
104104

105105
weaks = [True, 1, 2.0, 3j]
106-
non_weaks = [tnp.asarray(True),
107-
tnp.uint8(1), tnp.int8(1), tnp.int32(1), tnp.int64(1),
108-
tnp.float32(1), tnp.float64(1),
109-
tnp.complex64(1), tnp.complex128(1)]
106+
non_weaks = [
107+
tnp.asarray(True),
108+
tnp.uint8(1),
109+
tnp.int8(1),
110+
tnp.int32(1),
111+
tnp.int64(1),
112+
tnp.float32(1),
113+
tnp.float64(1),
114+
tnp.complex64(1),
115+
tnp.complex128(1),
116+
]
117+
110118

111119
@pytest.mark.skipif(not HAVE_NUMPY, reason="NumPy not found")
112120
@pytest.mark.parametrize("scalar, array", itertools.product(weaks, non_weaks))
@@ -121,6 +129,5 @@ def test_direct_compare(scalar, array):
121129
assert result.dtype == result_numpy.dtype
122130
assert result == result_numpy
123131

124-
125132
finally:
126133
_np._set_promotion_state(state)

0 commit comments

Comments
 (0)