Skip to content

Commit aa2d240

Browse files
committed
MAINT: remove local imports from einsum
1 parent 211a461 commit aa2d240

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

torch_np/_funcs_impl.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929
normalize_array_like,
3030
)
3131

32+
# these imports are for einsum only
33+
from ._normalizations import (
34+
maybe_copy_to,
35+
normalize_casting,
36+
normalize_dtype,
37+
wrap_tensors,
38+
) # isort: skip
39+
40+
3241
# ###### array creation routines
3342

3443

@@ -1233,12 +1242,6 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12331242
# Have to manually normalize *operands and **kwargs, following the NumPy signature
12341243

12351244
from ._ndarray import ndarray
1236-
from ._normalizations import (
1237-
maybe_copy_to,
1238-
normalize_casting,
1239-
normalize_dtype,
1240-
wrap_tensors,
1241-
)
12421245

12431246
dtype = normalize_dtype(dtype)
12441247
casting = normalize_casting(casting)
@@ -1276,7 +1279,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
12761279

12771280
is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
12781281
if is_short_int:
1279-
target_dtype, result_dtype = torch.int64, target_dtype
1282+
target_dtype = torch.int64
12801283

12811284
tensors = _util.typecast_tensors(tensors, target_dtype, casting)
12821285

0 commit comments

Comments
 (0)