Skip to content

Commit 6186340

Browse files
committed
restore Dot22 rewrites
1 parent 258de42 commit 6186340

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pytensor.graph.rewriting.basic import node_rewriter
44
from pytensor.tensor import basic as at
5+
from pytensor.tensor.blas import Dot22
56
from pytensor.tensor.elemwise import DimShuffle
67
from pytensor.tensor.math import Dot, Prod, dot, log
78
from pytensor.tensor.math import pow as at_pow
@@ -31,12 +32,12 @@ def transinv_to_invtrans(fgraph, node):
3132

3233

3334
@register_stabilize
34-
@node_rewriter([Dot])
35+
@node_rewriter([Dot, Dot22])
3536
def inv_as_solve(fgraph, node):
3637
"""
3738
This utilizes a boolean `symmetric` tag on the matrices.
3839
"""
39-
if isinstance(node.op, Dot):
40+
if isinstance(node.op, (Dot, Dot22)):
4041
l, r = node.inputs
4142
if l.owner and isinstance(l.owner.op, MatrixInverse):
4243
return [solve(l.owner.inputs[0], r)]
@@ -122,7 +123,7 @@ def cholesky_ldotlt(fgraph, node):
122123
return
123124

124125
A = node.inputs[0]
125-
if not (A.owner and isinstance(A.owner.op, Dot)):
126+
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
126127
return
127128

128129
l, r = A.owner.inputs

0 commit comments

Comments
 (0)