Skip to content

Commit 89d0523

Browse files
committed
Fix vectorize_node function name
1 parent e180927 commit 89d0523

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytensor/tensor/math.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2948,9 +2948,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29482948

29492949

29502950
@_vectorize_node.register(Dot)
2951-
def vectorize_node_to_matmul(op, node, batched_x, batched_y):
2951+
def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
29522952
old_x, old_y = node.inputs
29532953
if old_x.type.ndim == 2 and old_y.type.ndim == 2:
2954+
# If original input is equivalent to a matrix-matrix product,
2955+
# return specialized Matmul Op to avoid unnecessary new Ops.
29542956
return matmul(batched_x, batched_y).owner
29552957
else:
29562958
return vectorize_node_fallback(op, node, batched_x, batched_y)

0 commit comments

Comments
 (0)