diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py
index 660c16d387..e119b6de11 100644
--- a/pytensor/tensor/einsum.py
+++ b/pytensor/tensor/einsum.py
@@ -417,6 +417,25 @@ def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
     return tuple(pairwise(reversed(range(n))))
 
 
+def _ensure_not_equal(elements):
+    """Return ``elements`` as a list in which no object appears twice.
+
+    Entries are compared by identity (``is``), not by ``==``. The first
+    occurrence of an object is kept as is; each later occurrence of that same
+    object is replaced by the result of calling ``.copy()`` on it.
+    """
+    seen_ids = set()
+    distinct = []
+    for element in elements:
+        if id(element) in seen_ids:
+            # Same object passed again: substitute a copy of it.
+            element = element.copy()
+        else:
+            seen_ids.add(id(element))
+        distinct.append(element)
+    return distinct
+
+
 def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
@@ -553,7 +572,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             "If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
         )
 
-    tensor_operands = [as_tensor(operand) for operand in operands]
+    tensor_operands = _ensure_not_equal([as_tensor(operand) for operand in operands])
     shapes = [operand.type.shape for operand in tensor_operands]
 
     path: PATH
diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py
index ba8e354518..951e9a0c54 100644
--- a/tests/tensor/test_einsum.py
+++ b/tests/tensor/test_einsum.py
@@ -8,6 +8,7 @@
 from pytensor import Mode, config, function
 from pytensor.graph import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor import matrix
 from pytensor.tensor.basic import moveaxis
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
@@ -281,3 +282,15 @@ def test_threeway_mul(static_length):
         out.eval({x: x_test, y: y_test, z: z_test}),
         np.full((3,), fill_value=6),
     )
+
+
+def test_repeated_inputs():
+    x = matrix("x")
+    out_repeated = einsum("ij,ij->i", x, x)
+    out_copy = einsum("ij,ij->i", x, x.copy())
+
+    x_test = np.array([[1, 2], [3, 4]]).astype(x.dtype)
+    repeated_res = out_repeated.eval({x: x_test})
+
+    np.testing.assert_allclose(repeated_res, out_copy.eval({x: x_test}))
+    np.testing.assert_allclose(repeated_res, np.einsum("ij,ij->i", x_test, x_test))