diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index bc3eef6fca..f61cdc52b7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -215,8 +215,8 @@ def psd_solve_with_chol(fgraph, node): # N.B. this can be further reduced to a yet-unwritten cho_solve Op # __if__ no other Op makes use of the L matrix during the # stabilization - Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2) - x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2) + Li_b = solve_triangular(L, b, lower=True, b_ndim=2) + x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2) return [x] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9cdb69ce6b..54ee110f6d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -241,6 +241,33 @@ def test_local_det_chol(): assert not any(isinstance(node, Det) for node in nodes) +def test_psd_solve_with_chol(): + X = matrix("X") + X.tag.psd = True + X_inv = pt.linalg.solve(X, pt.identity_like(X)) + + f = function([X], X_inv, mode="FAST_RUN") + + nodes = f.maker.fgraph.apply_nodes + + assert not any(isinstance(node.op, Solve) for node in nodes) + assert any(isinstance(node.op, Cholesky) for node in nodes) + assert any(isinstance(node.op, SolveTriangular) for node in nodes) + + # Numeric test + rng = np.random.default_rng(sum(map(ord, "test_psd_solve_with_chol"))) + + L = rng.normal(size=(5, 5)).astype(config.floatX) + X_psd = L @ L.T + X_psd_inv = f(X_psd) + assert_allclose( + X_psd_inv, + np.linalg.inv(X_psd), + atol=1e-4 if config.floatX == "float32" else 1e-8, + rtol=1e-4 if config.floatX == "float32" else 1e-8, + ) + + class TestBatchedVectorBSolveToMatrixBSolve: rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"