Skip to content

Commit 51a4d92

Browse files
committed
Fix TensorVariable __rmatmul__
1 parent 326cb2e commit 51a4d92

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

pytensor/tensor/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def __matmul__(left, right):
652652
return at.math.matmul(left, right)
653653

654654
def __rmatmul__(right, left):
655-
return at.math.matmul(right, left)
655+
return at.math.matmul(left, right)
656656

657657
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
658658
"""See :func:`pytensor.tensor.math.sum`."""

tests/tensor/test_variable.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,15 @@ def test_infix_matmul_method():
9898
assert equal_computations([res], [exp_res])
9999

100100
X_val = np.arange(2 * 3).reshape((2, 3))
101-
res = as_tensor(X_val) @ y
101+
res = X_val @ y
102102
exp_res = matmul(X_val, y)
103103
assert equal_computations([res], [exp_res])
104104

105+
y_val = np.arange(3)
106+
res = X @ y_val
107+
exp_res = matmul(X, y_val)
108+
assert equal_computations([res], [exp_res])
109+
105110

106111
def test_empty_list_indexing():
107112
ynp = np.zeros((2, 2))[:, []]

0 commit comments

Comments
 (0)