Rewrite determinant of diagonal matrix as product of diagonal #797
Conversation
The big missing feature in this PR is detecting diagonal matrices that came from
Produces:
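For reference, here is a minimal snippet (not from the PR) to inspect the graph that pt.diag builds; the exact dprint output depends on the PyTensor version:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = pt.linalg.det(pt.diag(x))
# Print the graph; depending on the version and rewrite phase, the diagonal matrix
# shows up as x written into a zero Alloc at (ARange, ARange) indices, which is the
# structure discussed below.
pytensor.dprint(out)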
@ricardoV94 what do you think of wrapping
We can wrap it, but this actually doesn't look that crazy? The add zero will be removed (why is it there though?). This can probably even be detected with PatternSubstitution (or whatever it's called)
I assumed the add zero was broadcasting. If we can do it with a pattern replace, that'd be cool. @tanish1729 was already looking into those. I was worried that it would be hard to distinguish this specific graph from other arbitrary user allocations. If you don't think so, let's see how far we can go without doing extra work.
The allocations are easy, because the first argument (the value) must be zero. Doesn't matter what shape it was otherwise.
The issue with PatternSubstitution is that it will only work for the specific number of dims you use in the pattern. But even with a traditional rewrite I think this is feasible.
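For illustration only, here is a rough sketch of what that traditional rewrite could look like. This is not the PR's implementation; the name det_of_diag_to_prod is made up, and it assumes the diagonal matrix reaches the rewrite as AdvancedSetSubtensor(Alloc(0, ...), x, arange, arange) as discussed above, glossing over start/stop/step and dtype checks.

import numpy as np

import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import Alloc, ARange
from pytensor.tensor.subtensor import AdvancedIncSubtensor


@node_rewriter([pt.linalg.det])
def det_of_diag_to_prod(fgraph, node):
    (m,) = node.inputs
    # The matrix must come from an advanced *set* subtensor (not an increment)
    if (
        m.owner is None
        or not isinstance(m.owner.op, AdvancedIncSubtensor)
        or not m.owner.op.set_instead_of_inc
    ):
        return None
    zeros, diag_values, *indices = m.owner.inputs
    # The base allocation must be filled with a constant zero
    if zeros.owner is None or not isinstance(zeros.owner.op, Alloc):
        return None
    fill_value = zeros.owner.inputs[0]
    if not (isinstance(fill_value, Constant) and np.all(fill_value.data == 0)):
        return None
    # Both index vectors must be ARanges; a complete rewrite would also check
    # start == 0, step == 1, that stop matches the (square) allocated shape,
    # and cast the result back to the dtype of the det output.
    if len(indices) != 2 or not all(
        idx.owner is not None and isinstance(idx.owner.op, ARange) for idx in indices
    ):
        return None
    return [pt.prod(diag_values)]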
Couldn't we use the shape function in the pattern rewriter? Oh nvm, I misread your comment. I see your point now. I'll say that in general I think using OpFromGraph to represent basic operations like
Agree, but the counterargument to OFG is that the pattern we're matching would actually be more general than
The graph literally says that :)
Says, not screams :P |
I tried this to refresh my memory of how the PatternNodeRewriter works:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.utils import rewrite_graph
set_subtensor = pt.subtensor.advanced_set_subtensor
arange = pt.basic.ARange("int64")
rewrite = PatternNodeRewriter(
    # Strings in the pattern ("sh1", "sh2", "stop", "x") are wildcards, so nothing
    # here ties the alloc shape or the arange stop to the length of x yet.
    (pt.linalg.det, (set_subtensor, (pt.alloc, 0, "sh1", "sh2"), "x", (arange, 0, "stop", 1), (arange, 0, "stop", 1))),
    (pt.prod, "x"),
    name="determinant_of_diagonal",
    allow_multiple_clients=True,
)
register_canonicalize(rewrite)
x = pt.vector("x")
out = pt.linalg.det(pt.diag(x))
new_out = rewrite_graph(out)
assert equal_computations([new_out], [pt.prod(x)])
I guess the arange stop and the alloc shape must match the shape of x, otherwise the determinant would be zero?
Yeah, det is undefined for a non-square matrix.
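A quick numeric illustration (not from the PR) of both points:

import numpy as np

x = np.array([2.0, 3.0])

# If the alloc/arange stop were larger than len(x), the leftover diagonal
# entries stay 0, so the determinant is 0 rather than x.prod() == 6.
m = np.zeros((3, 3))
m[np.arange(2), np.arange(2)] = x
assert np.isclose(np.linalg.det(m), 0.0)

# And if the two shape arguments disagree, the matrix is not square and the
# determinant is simply undefined.
try:
    np.linalg.det(np.zeros((2, 3)))
except np.linalg.LinAlgError:
    pass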
Here is the updated version:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.shape import Shape_i
set_subtensor = pt.subtensor.advanced_set_subtensor
arange = pt.basic.ARange("int64")
# (Shape_i(0), "x") is a sub-pattern matching "the length of x along axis 0",
# so the alloc shape and the arange stops below are all tied to x's own length.
x_shape = (Shape_i(0), "x")
rewrite = PatternNodeRewriter(
    (pt.linalg.det, (set_subtensor, (pt.alloc, 0, x_shape, x_shape), "x", (arange, 0, x_shape, 1), (arange, 0, x_shape, 1))),
    (pt.prod, "x"),
    name="determinant_of_diagonal",
    allow_multiple_clients=True,
)
register_canonicalize(rewrite)
x = pt.vector("x")
out = pt.linalg.det(pt.diag(x))
# "ShapeOpt" is included so that shape expressions in the graph are lowered to
# Shape_i nodes, which the pattern above expects, before canonicalize runs.
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
assert equal_computations([new_out], [pt.prod(x)])
Unrelated, and not needed for this PR, but it would be nice for pattern rewrites to be able to enforce the type of x, especially ndim.
It looks great @tanish1729, the only question left is whether you are correctly handling eye that puts the 1s off the diagonal (k != 0)?
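A small numeric check (not from the PR) of why k matters: with k != 0 the ones sit off the main diagonal, the matrix is singular, and rewriting its determinant to a product of the stored values would give the wrong answer.

import numpy as np

assert np.isclose(np.linalg.det(np.eye(3, k=0)), 1.0)  # diagonal: det == prod of the ones
assert np.isclose(np.linalg.det(np.eye(3, k=1)), 0.0)  # off-diagonal: singular, det != 1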
Co-authored-by: Ricardo Vieira <[email protected]>
…nsor into det-diag-rewrite
added a suggestion using the GitHub suggestion thingy
@ricardoV94 I wasn't sure how to get the original type of the det value, so I did what made most sense logically. Let me know if there's a better, more correct way to do this.
Small typo in the comment, looks great otherwise
Description
Related Issue
Checklist
Type of change