-
Notifications
You must be signed in to change notification settings - Fork 134
Fix bug in tag_solve_triangular
rewrite
#383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in tag_solve_triangular
rewrite
#383
Conversation
pytensor/tensor/rewriting/linalg.py
Outdated
@@ -52,29 +52,28 @@ def inv_as_solve(fgraph, node): | |||
@node_rewriter([Solve]) | |||
def tag_solve_triangular(fgraph, node): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT? The test name should also be changed then
def tag_solve_triangular(fgraph, node): | |
def solve_cholesky_to_solve_triangular(fgraph, node): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah agreed, "tag" doesn't mean anything. But we shouldn't call it "solve_cholesky", because there's a separate Op
called SolveCholesky
. Maybe generic_solve_to_solve_triangular
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good
tag_solve_triangular
rewritetag_solve_triangular
rewrite
You got one of those pesky float32 test failures. LGTM otherwise. Just left a suggestion of giving a more descriptive name to the rewrite |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #383 +/- ##
==========================================
- Coverage 80.46% 80.46% -0.01%
==========================================
Files 156 156
Lines 45515 45514 -1
Branches 11149 11148 -1
==========================================
- Hits 36625 36624 -1
Misses 6688 6688
Partials 2202 2202
|
Motivation for these changes
Currently, the
tag_solve_trangular
rewrite looks for graphs like this:And replaces then with this:
The
Solve(assume_a='sym', lower=True)
solver was incorrectly assumed to be a triangular solver. In fact, it is a symmetric solver that decomposes the symmetricA
matrix into a triangular matrix. But given a triangular matrix, it produces incorrect results. This can be verified directly by testingL @ L_inv = eye
:Actually, even before the rewrite this code fails, so there's something deeper going on:
But one step at a time. This PR replaces
Solve(assume_a='sym', lower=True)
withSolveTriangular
in thetag_solve_triangular
rewrite, and adds a test for correct computation to the rewrite. That computational test can be put somewhere else once we track down what's going on withSolve
more generally, but it's a good stop-gap for now.Checklist
Major / Breaking Changes
New features
Bugfixes
Cholesky
Op
followed by aSolve
op will no longer be incorrectly computed.Documentation
Maintenance
Closes #382