Skip to content

Fix bug in tag_solve_triangular rewrite #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 15, 2023

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jul 14, 2023

Motivation for these changes

Currently, the tag_solve_trangular rewrite looks for graphs like this:

import pytensor
import pytensor.tensor as pt
import numpy as np
from numpy.testing import assert_allclose

A = pt.dmatrix('A')
b = pt.dmatrix('b')
L = pt.linalg.cholesky(A)
L_inv = pt.linalg.solve(L, b, assume_a='gen')
f = pytensor.function([A, b], [L_inv])

pytensor.dprint(L_inv)

Solve{assume_a='gen', lower=True, check_finite=True} [id A]
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B]
 │  └─ A [id C]
 └─ b [id D]

And replaces then with this:

pytensor.dprint(f)

Solve{assume_a='sym', lower=True, check_finite=True} [id A] 1
 ├─ Cholesky{lower=True, destructive=False, on_error='raise'} [id B] 0
 │  └─ A [id C]
 └─ b [id D]

The Solve(assume_a='sym', lower=True) solver was incorrectly assumed to be a triangular solver. In fact, it is a symmetric solver that decomposes the symmetric A matrix into a triangular matrix. But given a triangular matrix, it produces incorrect results. This can be verified directly by testing L @ L_inv = eye:

n = 3
Z = np.random.normal(size=(n, n))
P = Z @ Z.T
P_chol = np.linalg.cholesky(P)
eye = np.eye(n)

assert_allclose(f(P, eye) @ P_chol, eye) # fails

Actually, even before the rewrite this code fails, so there's something deeper going on:

assert_allclose(f(L_inv.eval({A:P, b:eye}) @ P_chol, eye) # also fails

But one step at a time. This PR replaces Solve(assume_a='sym', lower=True) with SolveTriangular in the tag_solve_triangular rewrite, and adds a test for correct computation to the rewrite. That computational test can be put somewhere else once we track down what's going on with Solve more generally, but it's a good stop-gap for now.

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • Graphs with a Cholesky Op followed by a Solve op will no longer be incorrectly computed.

Documentation

  • ...

Maintenance

  • ...

Closes #382

@@ -52,29 +52,28 @@ def inv_as_solve(fgraph, node):
@node_rewriter([Solve])
def tag_solve_triangular(fgraph, node):
Copy link
Member

@ricardoV94 ricardoV94 Jul 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT? The test name should also be changed then

Suggested change
def tag_solve_triangular(fgraph, node):
def solve_cholesky_to_solve_triangular(fgraph, node):

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah agreed, "tag" doesn't mean anything. But we shouldn't call it "solve_cholesky", because there's a separate Op called SolveCholesky. Maybe generic_solve_to_solve_triangular?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting linalg Linear algebra labels Jul 14, 2023
@ricardoV94 ricardoV94 changed the title fix bug in tag_solve_triangular rewrite Fix bug in tag_solve_triangular rewrite Jul 14, 2023
@ricardoV94
Copy link
Member

You got one of those pesky float32 test failures. LGTM otherwise. Just left a suggestion of giving a more descriptive name to the rewrite

@codecov-commenter
Copy link

codecov-commenter commented Jul 15, 2023

Codecov Report

Merging #383 (455707e) into main (7a82a3f) will decrease coverage by 0.01%.
The diff coverage is 69.23%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #383      +/-   ##
==========================================
- Coverage   80.46%   80.46%   -0.01%     
==========================================
  Files         156      156              
  Lines       45515    45514       -1     
  Branches    11149    11148       -1     
==========================================
- Hits        36625    36624       -1     
  Misses       6688     6688              
  Partials     2202     2202              
Impacted Files Coverage Δ
pytensor/tensor/rewriting/linalg.py 71.96% <69.23%> (-0.26%) ⬇️

@ricardoV94 ricardoV94 merged commit 9be43d0 into pymc-devs:main Jul 15, 2023
@jessegrabowski jessegrabowski deleted the fix-tag-solve-triangular branch July 22, 2023 06:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working graph rewriting linalg Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: tag_solve_triangular doesn't use a triangular solver
3 participants