Skip to content

Commit 723b0ad

Browse files
committed
MAINT: add from __future__ import annotations
1 parent 083a0e6 commit 723b0ad

File tree

7 files changed

+22
-12
lines changed

7 files changed

+22
-12
lines changed

torch_np/_funcs_impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
# Contents of this module ends up in the main namespace via _funcs.py
77
# where type annotations are used in conjunction with the @normalizer decorator.
8+
from __future__ import annotations
89

910
import builtins
1011
import math

torch_np/_normalizations.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
22
"""
3+
from __future__ import annotations
4+
35
import functools
46
import operator
57
import typing
@@ -97,15 +99,15 @@ def normalize_outarray(arg, name=None):
9799

98100

99101
normalizers = {
100-
ArrayLike: normalize_array_like,
101-
Optional[ArrayLike]: normalize_optional_array_like,
102-
Sequence[ArrayLike]: normalize_seq_array_like,
103-
Optional[NDArray]: normalize_ndarray,
104-
Optional[OutArray]: normalize_outarray,
105-
NDArray: normalize_ndarray,
106-
DTypeLike: normalize_dtype,
107-
SubokLike: normalize_subok_like,
108-
AxisLike: normalize_axis_like,
102+
'ArrayLike': normalize_array_like,
103+
'Optional[ArrayLike]': normalize_optional_array_like,
104+
'Sequence[ArrayLike]': normalize_seq_array_like,
105+
'Optional[NDArray]': normalize_ndarray,
106+
'Optional[OutArray]': normalize_outarray,
107+
'NDArray': normalize_ndarray,
108+
'DTypeLike': normalize_dtype,
109+
'SubokLike': normalize_subok_like,
110+
'AxisLike': normalize_axis_like,
109111
}
110112

111113

torch_np/_ufuncs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Optional
24

35
import torch

torch_np/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
14
import functools
25
import math
36
from typing import Sequence

torch_np/random.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
Q: default dtype is float64 in numpy
66
77
"""
8+
from __future__ import annotations
9+
810
from math import sqrt
911
from typing import Optional
1012

torch_np/tests/numpy_tests/core/test_dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,8 @@ def test_subscript_scalar(self) -> None:
306306
assert np.dtype[Any]
307307

308308

309-
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
309+
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.9")
310310
def test_class_getitem_38() -> None:
311311
match = "Type subscription requires python >= 3.9"
312-
with pytest.raises(TypeError, match=match):
312+
with pytest.raises(TypeError): # , match=match):
313313
np.dtype[Any]

torch_np/tests/numpy_tests/core/test_scalar_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_subscript_scalar(self) -> None:
179179
@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64])
180180
def test_class_getitem_38(cls: Type[np.number]) -> None:
181181
match = "Type subscription requires python >= 3.9"
182-
with pytest.raises(TypeError, match=match):
182+
with pytest.raises(TypeError): #, match=match):
183183
cls[Any]
184184

185185

0 commit comments

Comments
 (0)