Skip to content

Commit d8ef2ad

Browse files
fix bug in ag_solve_triangular rewrite
1 parent 7a82a3f commit d8ef2ad

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
register_specialize,
1212
register_stabilize,
1313
)
14-
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve
14+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
1515

1616

1717
logger = logging.getLogger(__name__)
@@ -52,29 +52,28 @@ def inv_as_solve(fgraph, node):
5252
@node_rewriter([Solve])
5353
def tag_solve_triangular(fgraph, node):
5454
"""
55-
If a general solve() is applied to the output of a cholesky op, then
55+
If any solve() is applied to the output of a cholesky op, then
5656
replace it with a triangular solve.
5757
5858
"""
5959
if isinstance(node.op, Solve):
60-
if node.op.assume_a == "gen":
61-
A, b = node.inputs # result is solution Ax=b
62-
if A.owner and isinstance(A.owner.op, Cholesky):
63-
if A.owner.op.lower:
64-
return [Solve(assume_a="sym", lower=True)(A, b)]
60+
A, b = node.inputs # result is solution Ax=b
61+
if A.owner and isinstance(A.owner.op, Cholesky):
62+
if A.owner.op.lower:
63+
return [SolveTriangular(lower=True)(A, b)]
64+
else:
65+
return [SolveTriangular(lower=False)(A, b)]
66+
if (
67+
A.owner
68+
and isinstance(A.owner.op, DimShuffle)
69+
and A.owner.op.new_order == (1, 0)
70+
):
71+
(A_T,) = A.owner.inputs
72+
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
73+
if A_T.owner.op.lower:
74+
return [SolveTriangular(lower=False)(A, b)]
6575
else:
66-
return [Solve(assume_a="sym", lower=False)(A, b)]
67-
if (
68-
A.owner
69-
and isinstance(A.owner.op, DimShuffle)
70-
and A.owner.op.new_order == (1, 0)
71-
):
72-
(A_T,) = A.owner.inputs
73-
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
74-
if A_T.owner.op.lower:
75-
return [Solve(assume_a="sym", lower=False)(A, b)]
76-
else:
77-
return [Solve(assume_a="sym", lower=True)(A, b)]
76+
return [SolveTriangular(lower=True)(A, b)]
7877

7978

8079
@register_canonicalize

tests/tensor/rewriting/test_linalg.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy.linalg
33
import pytest
44
import scipy.linalg
5+
from numpy.testing import assert_allclose
56

67
import pytensor
78
from pytensor import function
@@ -12,7 +13,7 @@
1213
from pytensor.tensor.math import _allclose
1314
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
1415
from pytensor.tensor.rewriting.linalg import inv_as_solve
15-
from pytensor.tensor.slinalg import Cholesky, Solve, solve
16+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
1617
from pytensor.tensor.type import dmatrix, matrix, vector
1718
from tests import unittest_tools as utt
1819
from tests.test_rop import break_op
@@ -85,21 +86,36 @@ def test_tag_solve_triangular():
8586
cholesky_lower = Cholesky(lower=True)
8687
cholesky_upper = Cholesky(lower=False)
8788
A = matrix("A")
88-
x = vector("x")
89+
x = matrix("x")
90+
8991
L = cholesky_lower(A)
9092
U = cholesky_upper(A)
9193
b1 = solve(L, x)
9294
b2 = solve(U, x)
9395
f = pytensor.function([A, x], b1)
96+
97+
X = np.random.normal(size=(10, 10))
98+
X = X @ X.T
99+
X_chol = np.linalg.cholesky(X)
100+
eye = np.eye(10)
101+
94102
if config.mode != "FAST_COMPILE":
95-
for node in f.maker.fgraph.toposort():
96-
if isinstance(node.op, Solve):
97-
assert node.op.assume_a != "gen" and node.op.lower
103+
toposort = f.maker.fgraph.toposort()
104+
op_list = [node.op for node in toposort]
105+
106+
assert not any(isinstance(op, Solve) for op in op_list)
107+
assert any(isinstance(op, SolveTriangular) for op in op_list)
108+
109+
assert_allclose(f(X, eye) @ X_chol, eye, atol=1e-8)
110+
98111
f = pytensor.function([A, x], b2)
112+
99113
if config.mode != "FAST_COMPILE":
100-
for node in f.maker.fgraph.toposort():
101-
if isinstance(node.op, Solve):
102-
assert node.op.assume_a != "gen" and not node.op.lower
114+
toposort = f.maker.fgraph.toposort()
115+
op_list = [node.op for node in toposort]
116+
assert not any(isinstance(op, Solve) for op in op_list)
117+
assert any(isinstance(op, SolveTriangular) for op in op_list)
118+
assert_allclose(f(X, eye).T @ X_chol, eye, atol=1e-8)
103119

104120

105121
def test_matrix_inverse_solve():

0 commit comments

Comments
 (0)