@register_stabilize
@register_canonicalize
@node_rewriter([Solve])
def generic_solve_to_solve_triangular(fgraph, node):
    """
    If a generic solve() is applied to the output of a cholesky op (or to
    its matrix transpose), then replace it with a triangular solve.

    Only ``Solve`` ops with ``assume_a == "gen"`` are rewritten.  With any
    other ``assume_a`` the solver is told to exploit structure of ``A`` and
    may read only one triangle of it, so substituting a triangular solve
    would not be guaranteed to reproduce the original result.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        The candidate apply node; its inputs are ``(A, b)`` for ``Ax = b``.

    Returns
    -------
    list or None
        A one-element list with the ``SolveTriangular`` replacement output,
        or ``None`` when the rewrite does not apply.
    """
    if not isinstance(node.op, Solve) or node.op.assume_a != "gen":
        return None

    A, b = node.inputs  # result is solution Ax=b
    if not A.owner:
        return None

    A_op = A.owner.op

    # Direct case: A is itself a Cholesky factor, so it has the same
    # triangularity as the factor the op was asked to produce.
    if isinstance(A_op, Cholesky):
        return [SolveTriangular(lower=bool(A_op.lower))(A, b)]

    # Transposed case: A = chol(...).T, which flips the triangularity.
    if isinstance(A_op, DimShuffle) and A_op.new_order == (1, 0):
        (A_T,) = A.owner.inputs
        if A_T.owner and isinstance(A_T.owner.op, Cholesky):
            return [SolveTriangular(lower=not A_T.owner.op.lower)(A, b)]

    return None
def test_generic_solve_to_solve_triangular():
    """Check that solve() on a Cholesky factor becomes a triangular solve.

    For both the lower and the upper factor, the compiled graph must no
    longer contain a ``Solve`` op and must contain a ``SolveTriangular`` op
    (whenever rewrites are active), and solving against the identity must
    return the inverse of the factor.
    """
    A = matrix("A")
    x = matrix("x")

    # Seeded generator plus a diagonal shift: the input is reproducible and
    # well conditioned, so the tolerance below is not at the mercy of an
    # unlucky random draw (this matters most for float32).
    rng = np.random.default_rng(20230515)
    eye = np.eye(10, dtype=config.floatX)
    G = rng.normal(size=(10, 10)).astype(config.floatX)
    X = G @ G.T + 10 * eye
    X_chol = np.linalg.cholesky(X)  # lower factor, used as the reference
    atol = 1e-8 if config.floatX.endswith("64") else 1e-4

    for lower in (True, False):
        factor = Cholesky(lower=lower)(A)
        f = pytensor.function([A, x], solve(factor, x))

        # Graph-structure checks only make sense when rewrites are applied.
        if config.mode != "FAST_COMPILE":
            op_list = [node.op for node in f.maker.fgraph.toposort()]
            assert not any(isinstance(op, Solve) for op in op_list)
            assert any(isinstance(op, SolveTriangular) for op in op_list)

        # The numerical result must be right in every mode.  For the upper
        # factor U = L.T, inv(U).T == inv(L), hence the transpose.
        factor_inv = f(X, eye)
        if not lower:
            factor_inv = factor_inv.T
        assert_allclose(factor_inv @ X_chol, eye, atol=atol)