Skip to content

Commit 868b56b

Browse files
committed
MAINT: review comments
1 parent 344b3f7 commit 868b56b

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

torch_np/_funcs_impl.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,14 +1260,13 @@ def einsum(*operands, out=None, optimize=False, **kwargs):
12601260
raise TypeError("unknown arguments: ", kwargs)
12611261

12621262
# parse arrays and normalize them
1263-
if isinstance(operands[0], str):
1264-
# ("ij->", arrays) format
1265-
sublist_format = False
1266-
subscripts, array_operands = operands[0], operands[1:]
1267-
else:
1263+
sublist_format = not isinstance(operands[0], str)
1264+
if sublist_format:
12681265
# op, str, op, str ... format: normalize every other argument
1269-
sublist_format = True
12701266
array_operands = operands[:-1][::2]
1267+
else:
1268+
# ("ij->", arrays) format
1269+
subscripts, array_operands = operands[0], operands[1:]
12711270

12721271
tensors = [normalize_array_like(op) for op in array_operands]
12731272
target_dtype = (
@@ -1291,8 +1290,6 @@ def einsum(*operands, out=None, optimize=False, **kwargs):
12911290
else:
12921291
result = torch.einsum(subscripts, *tensors)
12931292

1294-
1295-
12961293
result = maybe_copy_to(out, result)
12971294
return wrap_tensors(result)
12981295

0 commit comments

Comments
 (0)