
Commit 211a461

einsum: forward optimize= to torch.backends

1 parent 1dd74bf commit 211a461

1 file changed: +21 -13 lines changed

torch_np/_funcs_impl.py

@@ -1280,19 +1280,27 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
 
     tensors = _util.typecast_tensors(tensors, target_dtype, casting)
 
-    if sublist_format:
-        # recombine operands
-        sublists = operands[1::2]
-        has_sublistout = len(operands) % 2 == 1
-        if has_sublistout:
-            sublistout = operands[-1]
-        operands = list(itertools.chain(*zip(tensors, sublists)))
-        if has_sublistout:
-            operands.append(sublistout)
-
-        result = torch.einsum(*operands)
-    else:
-        result = torch.einsum(subscripts, *tensors)
+    try:
+        # set the global state to handle the optimize=... argument, restore on exit
+        old_strategy = torch.backends.opt_einsum.strategy
+        torch.backends.opt_einsum.strategy = optimize
+
+        if sublist_format:
+            # recombine operands
+            sublists = operands[1::2]
+            has_sublistout = len(operands) % 2 == 1
+            if has_sublistout:
+                sublistout = operands[-1]
+            operands = list(itertools.chain(*zip(tensors, sublists)))
+            if has_sublistout:
+                operands.append(sublistout)
+
+            result = torch.einsum(*operands)
+        else:
+            result = torch.einsum(subscripts, *tensors)
+
+    finally:
+        torch.backends.opt_einsum.strategy = old_strategy
 
     result = maybe_copy_to(out, result)
     return wrap_tensors(result)
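
Note: torch.einsum takes no optimize keyword of its own, so the commit forwards NumPy's optimize= argument through the global torch.backends.opt_einsum.strategy setting, restoring the previous value in a finally block so an exception inside einsum cannot leak the override. Below is a minimal standalone sketch of the same save/restore pattern, written as a context manager; the einsum_strategy helper is illustrative rather than part of the commit, and it assumes a PyTorch build where the opt_einsum package is available so the strategy setting takes effect.

    import contextlib

    import torch

    @contextlib.contextmanager
    def einsum_strategy(strategy):
        # Save the current global path-search strategy, install the
        # override, and restore the original in ``finally`` so an
        # exception inside the block cannot leak the setting.
        old_strategy = torch.backends.opt_einsum.strategy
        torch.backends.opt_einsum.strategy = strategy
        try:
            yield
        finally:
            torch.backends.opt_einsum.strategy = old_strategy

    # Usage: run a three-operand contraction under the "greedy" strategy
    # (one of the documented values alongside "auto" and "optimal").
    a, b, c = (torch.randn(8, 8) for _ in range(3))
    with einsum_strategy("greedy"):
        out = torch.einsum("ij,jk,kl->il", a, b, c)

The commit inlines this pattern rather than using a context manager, but the effect is the same: the strategy override is scoped to the single einsum call.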
