Skip to content

Commit 4b21903

Browse files
committed
MAINT: remove local imports from einsum
1 parent 211a461 commit 4b21903

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

torch_np/_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,4 @@ def __getitem__(self, item):
6969
s_ = IndexExpression(maketuple=False)
7070

7171
__all__ += ["index_exp", "s_"]
72+

torch_np/_funcs_impl.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
NotImplementedType,
2828
OutArray,
2929
normalize_array_like,
30+
31+
# these imports are for einsum
32+
maybe_copy_to,
33+
normalize_casting,
34+
normalize_dtype,
35+
wrap_tensors,
3036
)
3137

3238
# ###### array creation routines
@@ -1233,12 +1239,6 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12331239
# Have to manually normalize *operands and **kwargs, following the NumPy signature
12341240

12351241
from ._ndarray import ndarray
1236-
from ._normalizations import (
1237-
maybe_copy_to,
1238-
normalize_casting,
1239-
normalize_dtype,
1240-
wrap_tensors,
1241-
)
12421242

12431243
dtype = normalize_dtype(dtype)
12441244
casting = normalize_casting(casting)
@@ -1276,7 +1276,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12761276

12771277
is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
12781278
if is_short_int:
1279-
target_dtype, result_dtype = torch.int64, target_dtype
1279+
target_dtype = torch.int64
12801280

12811281
tensors = _util.typecast_tensors(tensors, target_dtype, casting)
12821282

0 commit comments

Comments
 (0)