Skip to content

Fix bug in tag_solve_triangular rewrite #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,29 +52,28 @@ def inv_as_solve(fgraph, node):
@node_rewriter([Solve])
def tag_solve_triangular(fgraph, node):
Copy link
Member

@ricardoV94 ricardoV94 Jul 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think? The test name should also be changed in that case.

Suggested change
def tag_solve_triangular(fgraph, node):
def solve_cholesky_to_solve_triangular(fgraph, node):

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, agreed: "tag" doesn't mean anything. But we shouldn't call it "solve_cholesky", because there's a separate Op called CholeskySolve. Maybe generic_solve_to_solve_triangular?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

"""
If a general solve() is applied to the output of a cholesky op, then
If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.

"""
if isinstance(node.op, Solve):
if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [Solve(assume_a="sym", lower=True)(A, b)]
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [SolveTriangular(lower=True)(A, b)]
else:
return [SolveTriangular(lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [SolveTriangular(lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [Solve(assume_a="sym", lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=True)(A, b)]
return [SolveTriangular(lower=True)(A, b)]


@register_canonicalize
Expand Down
32 changes: 24 additions & 8 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy.linalg
import pytest
import scipy.linalg
from numpy.testing import assert_allclose

import pytensor
from pytensor import function
Expand All @@ -12,7 +13,7 @@
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
Expand Down Expand Up @@ -85,21 +86,36 @@ def test_tag_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A")
x = vector("x")
x = matrix("x")

L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
f = pytensor.function([A, x], b1)

X = np.random.normal(size=(10, 10))
X = X @ X.T
X_chol = np.linalg.cholesky(X)
eye = np.eye(10)

if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]

assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)

assert_allclose(f(X, eye) @ X_chol, eye, atol=1e-8)

f = pytensor.function([A, x], b2)

if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and not node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]
assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(f(X, eye).T @ X_chol, eye, atol=1e-8)


def test_matrix_inverse_solve():
Expand Down