Add rewrite to lift linear algebra through certain linalg ops #622
Conversation
I tracked down the bug we hit during the hackathon. We were importing …. Now the problem is that …, so our batched tests for this rewrite currently fail for the batched case.

Thoughts?
Call pt.vectorize on the test from the 2D core case? Does the rewrite still hold as is, though?
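A rough sketch of what that could look like (the signature string, shapes, and kron's exact import location are assumptions for illustration):

```python
import pytensor.tensor as pt

a = pt.tensor("a", shape=(5, 2, 2))
b = pt.tensor("b", shape=(5, 3, 3))

# Vectorize the 2D core kron over a leading batch dimension.
batched_kron = pt.vectorize(pt.linalg.kron, signature="(m,n),(p,q)->(r,s)")
out = batched_kron(a, b)  # one (6, 6) kron per batch element -> (5, 6, 6)
```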
The rewrite works fine in the 2D case; it just fails in the batch case because of the way … works. This is actually an artifact of how we define kron. It seems like there are two ways: the numpy form and the scipy form.
Okay, so maybe let's add a kwarg to choose between the two forms. Regardless, we have to make sure we do the rewrite correctly (or restrict it to the cases where the logic is correct, although it doesn't sound like it should be too complicated to support both forms?)
For the scipy form, we need a way for the user to tell us which axes are batch axes and which are not. If we don't know, the rewrite will break for anything bigger than 2D. I guess if we had a …

For the numpy form, we technically still need to know this (because …).
Isn't our current implementation already the scipy kron? We don't need to do anything. A numpy kron is a Blockwised kron with 2D core inputs, which right now users would have to create manually. For this PR you can decide whether to support only the scipy kron or also a Blockwised kron. I don't see why the axis is needed: scipy mixes everything, if I understand correctly, and numpy only the last two axes. From the signature / absence of the Blockwise (and its signature) we should know which case it is.
Numpy's kron is not blockwise; it just multiplies all the axes together in a way that still respects commutativity with matrix operations. Here's a guide to all the outputs:
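A minimal NumPy sketch of the shape behavior (a batch of 3 and 2×2 cores, chosen just for illustration):

```python
import numpy as np

a = np.random.normal(size=(3, 2, 2))  # a "batch" of three 2x2 matrices
b = np.random.normal(size=(3, 2, 2))

# numpy's kron multiplies every axis pair, batch axes included:
np.kron(a, b).shape  # (9, 4, 4) -- the batch axes get kron'd too

# a blockwise kron over only the last two axes keeps the batch axis intact:
np.stack([np.kron(ai, bi) for ai, bi in zip(a, b)]).shape  # (3, 4, 4)
```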
After #684 this should be ready to go.
Things like this have happened in the past. My best hunch is a difference between checking individual commits vs. all files: locally mypy is doing one, in CI the other, and its opinion changes accordingly. Or the Python version. Or something else.
Probably the eclipse, for all I know. Mypy is something else.
Nice! It proved trickier because of the scipy thing, but not terrible in the end.
tests/compile/test_builders.py
```python
f = f - grad(pt_sum(f), y)
f = f - grad(pt_sum(f), y)
fn = function([x, y, z], f)
g = grad(pt_sum(f), y)
```
Can you add a new test for the error and leave this one unchanged? Or did this start failing? Looks like it shouldn't have?
Yes, it started failing.
Hmm, why? It was working before. May need to investigate, or is it obvious?
It merits investigation. But it might be that the test was always flaky and now we're catching it. The duplicated input is one of the intermediate f representations -- something like *4-<Matrix(?, ?)>. The true inputs x, y, z aren't the problem. I assume this has something to do with how the grad function works?
Intermediate variables can be repeated just fine
The test doesn't have obvious double inputs though; what's going on there? Is it because the original f is considered input to both grad(f.sum(), y) and grad(grad(f.sum(), y).sum(), y)?
Let me see what the duplicate inputs are. In any case, you proved we are too restrictive with the OFG([a], [a]) example.
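Spelled out, that degenerate example is just (a minimal sketch):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# An identity OpFromGraph: the single output *is* the single input.
# The graph is perfectly valid, but a strict duplicated-variable check
# across inputs/outputs would reject it.
a = pt.matrix("a")
identity_op = OpFromGraph([a], [a])
```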
The L_op of the first grad call looks like:

```
pytensor.dprint(self.fgraph, print_fgraph_inputs=True)

→ *0-<Matrix(float64, shape=(?, ?))> [id A]
→ *1-<Matrix(float64, shape=(?, ?))> [id B]
→ *2-<Matrix(float64, shape=(?, ?))> [id C]
→ *3-<Matrix(float64, shape=(?, ?))> [id D]
→ *4-<Matrix(float64, shape=(?, ?))> [id E]
*4-<Matrix(float64, shape=(?, ?))> [id E]
Mul [id F] 0
 ├─ *4-<Matrix(float64, shape=(?, ?))> [id E]
 └─ *2-<Matrix(float64, shape=(?, ?))> [id C]
Mul [id G] 1
 ├─ *4-<Matrix(float64, shape=(?, ?))> [id E]
 └─ *1-<Matrix(float64, shape=(?, ?))> [id B]
```
You can see the last input *4 (the output_grad) is also the first output (the input grad wrt x is just the output grad). So when the second L_op is called on the first one, it will try to build an OFG that takes *4 as an input twice: once from its inputs and once from its outputs.
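A hypothetical repro consistent with that dprint (not the exact test graph):

```python
import pytensor.tensor as pt
from pytensor import grad
from pytensor.compile.builders import OpFromGraph

x = pt.matrix("x")
y = pt.matrix("y")
z = pt.matrix("z")

# d(out)/dx is exactly the output grad, so the L_op graph returns one of
# its own inputs (*4) as its first output.
op = OpFromGraph([x, y, z], [x + y * z])
f = op(x, y, z)

g1 = grad(f.sum(), y)   # first L_op: builds the inner graph printed above
g2 = grad(g1.sum(), y)  # second L_op: wraps it in an OFG that receives *4
                        # both as an input and as one of its outputs
```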
I went with skipping the input check when on_unused_inputs == 'ignore'. I thought that was better than removing the error. I guess a real solution would be to check for outputs that are also inputs when we build the OpFromGraph, and clone them if so?
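Roughly, the workaround amounts to something like this (a rough sketch; the helper and its name are invented for illustration, not the PR's actual code):

```python
def check_ofg_inputs(inputs, on_unused_inputs):
    # Skip the strict check entirely when the user has opted out.
    if on_unused_inputs == "ignore":
        return
    # Otherwise reject graphs that pass the same variable in twice.
    if len(set(inputs)) != len(inputs):
        raise ValueError("OpFromGraph received duplicated inputs")
```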
Agree, but let's do that in a separate PR. Then we can get rid of this hacky param as well.
float32 😍
Codecov Report

All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##             main     #622      +/-   ##
==========================================
+ Coverage   80.76%   80.78%   +0.02%
==========================================
  Files         162      162
  Lines       46707    46757      +50
  Branches    11422    11440      +18
==========================================
+ Hits        37723    37773      +50
  Misses       6732     6732
  Partials     2252     2252
```
I squashed this PR down to two commits: one for the actual rewrite and one for the changes that touched OpFromGraph (which hopefully we can revert down the road). After the CI runs again I'll merge. Codecov is failing because we don't have tests for the OpFromGraph error -- do you care? I can handle this in the next PR, which will deal less hackishly with the duplicated inputs/outputs problem.
Yeah, we can merge as is.
I wouldn't remove the check later though, just not use kwargs. So a test is not a bad idea; I just don't want to drag the PR out too much more.
Looks good
Adds rewrites to lift certain linear algebra ops (inv, pinv, cholesky) through "compositions" of matrices (block_diag, kron).
Kron is currently broken, so it's marked as a draft.
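For context, the identities the rewrite exploits can be checked numerically with NumPy/SciPy (a minimal sketch; the block sizes and the diagonal shift for conditioning are arbitrary choices):

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2)) + 2 * np.eye(2)  # keep the blocks well-conditioned
B = rng.normal(size=(3, 3)) + 2 * np.eye(3)

# inv(block_diag(A, B)) == block_diag(inv(A), inv(B))
np.testing.assert_allclose(
    np.linalg.inv(block_diag(A, B)),
    block_diag(np.linalg.inv(A), np.linalg.inv(B)),
)

# inv(kron(A, B)) == kron(inv(A), inv(B))
np.testing.assert_allclose(
    np.linalg.inv(np.kron(A, B)),
    np.kron(np.linalg.inv(A), np.linalg.inv(B)),
)
```

Inverting the small factors is much cheaper than inverting the composed matrix, which is the point of lifting inv/pinv/cholesky through these ops.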