Skip to content

Commit 58fb850

Browse files
committed
Fix TensorVariable __rmatmul__
1 parent 326cb2e commit 58fb850

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

pytensor/tensor/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def __matmul__(left, right):
652652
return at.math.matmul(left, right)
653653

654654
def __rmatmul__(right, left):
655-
return at.math.matmul(right, left)
655+
return at.math.matmul(left, right)
656656

657657
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
658658
"""See :func:`pytensor.tensor.math.sum`."""

tests/tensor/test_variable.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile.mode import get_default_mode
1111
from pytensor.graph.basic import Constant, equal_computations
1212
from pytensor.tensor import get_vector_length
13-
from pytensor.tensor.basic import as_tensor, constant
13+
from pytensor.tensor.basic import constant
1414
from pytensor.tensor.elemwise import DimShuffle
1515
from pytensor.tensor.math import dot, eq, matmul
1616
from pytensor.tensor.shape import Shape
@@ -98,10 +98,15 @@ def test_infix_matmul_method():
9898
assert equal_computations([res], [exp_res])
9999

100100
X_val = np.arange(2 * 3).reshape((2, 3))
101-
res = as_tensor(X_val) @ y
101+
res = X_val @ y
102102
exp_res = matmul(X_val, y)
103103
assert equal_computations([res], [exp_res])
104104

105+
y_val = np.arange(3)
106+
res = X @ y_val
107+
exp_res = matmul(X, y_val)
108+
assert equal_computations([res], [exp_res])
109+
105110

106111
def test_empty_list_indexing():
107112
ynp = np.zeros((2, 2))[:, []]

0 commit comments

Comments
 (0)