Skip to content

Commit 00fea0e

Browse files
Fix einsum failing with repeated inputs (#1260)
* fixed Einsum failing with repeated inputs * Optimise the _ensure_not_equal function Co-authored-by: Ricardo Vieira <[email protected]> * Fix einsum failing on repeated inputs * Fix einsum failing with repeated inputs * Added regression test for repeated inputs to the einsum function * Fix for failing test Co-authored-by: Ricardo Vieira <[email protected]> --------- Co-authored-by: Ricardo Vieira <[email protected]>
1 parent c0860f8 commit 00fea0e

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pytensor/tensor/einsum.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
417417
return tuple(pairwise(reversed(range(n))))
418418

419419

420+
def _ensure_not_equal(elements):
421+
"""
422+
Ensures that any pair in a list of elements are not the same object. If a pair of elements is found to be equal, then one of them is converted to a copy.
423+
"""
424+
elements = list(elements)
425+
for i, elem1 in enumerate(elements[:-1]):
426+
for j, elem2 in enumerate(elements[i + 1 :], start=i + 1):
427+
if elem1 is elem2:
428+
elements[j] = elem1.copy()
429+
return elements
430+
431+
420432
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
421433
"""
422434
Multiplication and summation of tensors using the Einstein summation convention.
@@ -553,7 +565,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
553565
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
554566
)
555567

556-
tensor_operands = [as_tensor(operand) for operand in operands]
568+
tensor_operands = _ensure_not_equal([as_tensor(operand) for operand in operands])
557569
shapes = [operand.type.shape for operand in tensor_operands]
558570

559571
path: PATH

tests/tensor/test_einsum.py

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor import Mode, config, function
99
from pytensor.graph import FunctionGraph
1010
from pytensor.graph.op import HasInnerGraph
11+
from pytensor.tensor import matrix
1112
from pytensor.tensor.basic import moveaxis
1213
from pytensor.tensor.blockwise import Blockwise
1314
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
@@ -281,3 +282,15 @@ def test_threeway_mul(static_length):
281282
out.eval({x: x_test, y: y_test, z: z_test}),
282283
np.full((3,), fill_value=6),
283284
)
285+
286+
287+
def test_repeated_inputs():
288+
x = matrix("x")
289+
out_repeated = einsum("ij,ij->i", x, x)
290+
out_copy = einsum("ij,ij->i", x, x.copy())
291+
292+
x_test = np.array([[1, 2], [3, 4]]).astype(x.dtype)
293+
294+
np.testing.assert_allclose(
295+
out_repeated.eval({x: x_test}), out_copy.eval({x: x_test})
296+
)

0 commit comments

Comments
 (0)