Skip to content

Commit 7e9f49c

Browse files
committed
MAINT: remove local imports from einsum
1 parent 211a461 commit 7e9f49c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

torch_np/_funcs_impl.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@
1818
from . import _dtypes_impl
1919
from . import _reductions as _impl
2020
from . import _util
21-
from ._normalizations import (
21+
22+
# these imports are for einsum only
23+
from ._normalizations import ( # isort: skip
2224
ArrayLike,
2325
AxisLike,
2426
CastingModes,
2527
DTypeLike,
2628
NDArray,
2729
NotImplementedType,
2830
OutArray,
31+
maybe_copy_to,
2932
normalize_array_like,
33+
normalize_casting,
34+
normalize_dtype,
35+
wrap_tensors,
3036
)
3137

3238
# ###### array creation routines
@@ -1233,12 +1239,6 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12331239
# Have to manually normalize *operands and **kwargs, following the NumPy signature
12341240

12351241
from ._ndarray import ndarray
1236-
from ._normalizations import (
1237-
maybe_copy_to,
1238-
normalize_casting,
1239-
normalize_dtype,
1240-
wrap_tensors,
1241-
)
12421242

12431243
dtype = normalize_dtype(dtype)
12441244
casting = normalize_casting(casting)
@@ -1276,7 +1276,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12761276

12771277
is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
12781278
if is_short_int:
1279-
target_dtype, result_dtype = torch.int64, target_dtype
1279+
target_dtype = torch.int64
12801280

12811281
tensors = _util.typecast_tensors(tensors, target_dtype, casting)
12821282

0 commit comments

Comments
 (0)