From 709cecb0a774cb3b496703b2e644f1abf6bcbf9c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 00:14:39 +0100 Subject: [PATCH 1/4] Use `solve_triangular` instead of in `psd_solve_with_chol` --- pytensor/tensor/rewriting/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index bc3eef6fca..f61cdc52b7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -215,8 +215,8 @@ def psd_solve_with_chol(fgraph, node): # N.B. this can be further reduced to a yet-unwritten cho_solve Op # __if__ no other Op makes use of the L matrix during the # stabilization - Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2) - x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2) + Li_b = solve_triangular(L, b, lower=True, b_ndim=2) + x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2) return [x] From 4d83328aa2690d23ce4c91c7705c3f83084b80f8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 01:04:00 +0100 Subject: [PATCH 2/4] Add unittest for `psd_solve_with_chol` --- tests/tensor/rewriting/test_linalg.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9cdb69ce6b..83974082b4 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -241,6 +241,26 @@ def test_local_det_chol(): assert not any(isinstance(node, Det) for node in nodes) +def test_psd_solve_with_chol(): + X = matrix("X") + X.tag.psd = True + X_inv = pt.linalg.solve(X, pt.identity_like(X)) + + f = function([X], X_inv) + + nodes = f.maker.fgraph.apply_nodes + + assert not any(isinstance(node.op, Solve) for node in nodes) + assert any(isinstance(node.op, Cholesky) for node in nodes) + assert any(isinstance(node.op, SolveTriangular) for node in nodes) + + # Numeric test + L = np.random.randn(5, 5).astype(config.floatX) + X_psd = L @ L.T + X_psd_inv = f(X_psd) + assert_allclose(X_psd_inv, np.linalg.inv(X_psd)) + + class TestBatchedVectorBSolveToMatrixBSolve: rewrite_name = "batched_vector_b_solve_to_matrix_b_solve" From 47d290b7242740945653d1c574ea14ac09c735a4 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 01:32:20 +0100 Subject: [PATCH 3/4] Specify `mode=FAST_RUN` in test --- tests/tensor/rewriting/test_linalg.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 83974082b4..b39deb73f2 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -246,7 +246,7 @@ def test_psd_solve_with_chol(): X.tag.psd = True X_inv = pt.linalg.solve(X, pt.identity_like(X)) - f = function([X], X_inv) + f = function([X], X_inv, mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes @@ -255,7 +255,9 @@ def test_psd_solve_with_chol(): assert any(isinstance(node.op, SolveTriangular) for node in nodes) # Numeric test - L = np.random.randn(5, 5).astype(config.floatX) + rng = np.random.default_rng(sum(map(ord, "test_psd_solve_with_chol"))) + + L = rng.normal(size=(5, 5)).astype(config.floatX) X_psd = L @ L.T X_psd_inv = f(X_psd) assert_allclose(X_psd_inv, np.linalg.inv(X_psd)) From 2b823eefc80e997dbb32588f3dd9daedac149b08 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 6 Jan 2024 15:32:03 +0100 Subject: [PATCH 4/4] Relax `test_psd_solve_with_chol` `atol` and `rtol` for half-precision tests --- tests/tensor/rewriting/test_linalg.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index b39deb73f2..54ee110f6d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -260,7 +260,12 @@ def test_psd_solve_with_chol(): L = rng.normal(size=(5, 5)).astype(config.floatX) X_psd = L @ L.T X_psd_inv = f(X_psd) - assert_allclose(X_psd_inv, np.linalg.inv(X_psd)) + assert_allclose( + X_psd_inv, + np.linalg.inv(X_psd), + atol=1e-4 if config.floatX == "float32" else 1e-8, + rtol=1e-4 if config.floatX == "float32" else 1e-8, + ) class TestBatchedVectorBSolveToMatrixBSolve: