Skip to content

Rewrite determinant of diagonal matrix as product of diagonal #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jul 3, 2024

Conversation

tanish1729
Copy link
Contributor

Description

  • Implemented rewrite for determinant of a diagonal matrix

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jessegrabowski
Copy link
Member

The big missing features in this PR is detecting diagonal matrices that came from pt.diag(x). These produce quite a complex graph, so it's not so trivial to find them via graph inspection. For example:

x = pt.dvector('x')
x_diag = pt.diag(x)

x_diag.dprint()

Produces:

AdvancedSetSubtensor [id A]
 ├─ Alloc [id B]
 │  ├─ 0.0 [id C]
 │  ├─ Add [id D]
 │  │  ├─ Subtensor{i} [id E]
 │  │  │  ├─ Shape [id F]
 │  │  │  │  └─ x [id G]
 │  │  │  └─ -1 [id H]
 │  │  └─ 0 [id I]
 │  └─ Add [id D]
 │     └─ ···
 ├─ x [id G]
 ├─ Add [id J]
 │  ├─ ARange{dtype='int64'} [id K]
 │  │  ├─ 0 [id L]
 │  │  ├─ Subtensor{i} [id M]
 │  │  │  ├─ Shape [id N]
 │  │  │  │  └─ x [id G]
 │  │  │  └─ -1 [id O]
 │  │  └─ 1 [id P]
 │  └─ ExpandDims{axis=0} [id Q]
 │     └─ 0 [id R]
 └─ Add [id S]
    ├─ ARange{dtype='int64'} [id K]
    │  └─ ···
    └─ ExpandDims{axis=0} [id T]
       └─ 0 [id U]

@ricardoV94 what do you think of wrapping pt.diag in an OpFromGraph that would make it easier for this rewrite to find? Then we would just need to look for a Diagonal Op, rather than this subgraph. It would also pretty up the dprint for graphs that use pt.diag quite a bit.

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 7, 2024

We can wrap it but this actually doesn't look that crazy? It's a set_subtensor(zeros(m, n), x, arange(k), arange(k)). The m, n don't matter nor k, as long as they're the same.

The add zero will be removed (why is it there though?)

This can probably even be detected with PatternSubstitution (or whatever it's called)

@jessegrabowski
Copy link
Member

I assumed the add zero was broadcasting

If we can do it with a pattern replace that'd be cool. @tanish1729 was already looking into those. I was worried that it would be hard to distinguish this specific graph from other arbitrary user allocations. If you don't think so, let's see how far we can go without doing extra work.

@ricardoV94
Copy link
Member

I assumed the add zero was broadcasting

If we can do it with a pattern replace that'd be cool. @tanish1729 was already looking into those. I was worried that it would be hard to distinguish this specific graph from other arbitrary user allocations. If you don't think so, let's see how far we can go without doing extra work.

The allocations are easy, because the first argument (the value) must be zero. Doesn't matter what shape it was otherwise.

@ricardoV94
Copy link
Member

The issue with PatternSubstitution is that it will only work for the specific number of dims you use in the pattern. But even in traditional rewrite I think this is feasible

@jessegrabowski
Copy link
Member

jessegrabowski commented Jun 7, 2024

We couldn't use the shape function in the pattern rewriter?

Oh nvm I misread your comment. I see your point now.

I'll say that in general I think using OpFromGraph to represent basic operations like pt.diag is a good idea. Nothing about the graph generated by pt.diag screams "this is setting the diagonal of the zero matrix to the vector x". Especially if this was embedded in a larger graph, I think there's value add to hiding some of these details in the dprint.

@ricardoV94
Copy link
Member

I'll say that in general I think using OpFromGraph to represent basic operations like pt.diag is a good idea. Nothing about the graph generated by pt.diag screams "this is setting the diagonal of the zero matrix to the vector x". Especially if this was embedded in a larger graph, I think there's value add to hiding some of these details in the dprint.

dprint agree, but the counterargument to OFG is that the pattern we're matching would actually be more general than diagonal. Whether it happens in practice or not I can't tell, but once you wrote the code it would work regardless. It may also be a good practice :)

@ricardoV94
Copy link
Member

Nothing about the graph generated by pt.diag screams "this is setting the diagonal of the zero matrix to the vector x". Especially if this was embedded in a larger graph, I think there's value add to hiding some of these details in the dprint.

The graph literally says that :)

@jessegrabowski
Copy link
Member

Says, not screams :P

@ricardoV94
Copy link
Member

I tried this to refresh my mind how the PatternNodeRewriter works:

import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import equal_computations
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.utils import rewrite_graph 

set_subtensor = pt.subtensor.advanced_set_subtensor
arange = pt.basic.ARange("int64")
rewrite = PatternNodeRewriter(
    (pt.linalg.det, (set_subtensor, (pt.alloc, 0, "sh1", "sh2"), "x", (arange, 0, "stop", 1), (arange, 0, "stop", 1))),
    (pt.prod, "x"),
    name="determinant_of_diagonal",
    allow_multiple_clients=True,
)
register_canonicalize(rewrite)

x = pt.vector("x")
out = pt.linalg.det(pt.diag(x))

new_out = rewrite_graph(out)
assert equal_computations([new_out], [pt.prod(x)])

@ricardoV94
Copy link
Member

I guess the stop and alloc must be the same shape of x, otherwise the determinant would be zeros?

@jessegrabowski
Copy link
Member

Yea det is undefined for non-square matrix

@ricardoV94
Copy link
Member

Here is the updated version:

import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import equal_computations
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.graph.rewriting.utils import rewrite_graph 
from pytensor.tensor.shape import Shape_i

set_subtensor = pt.subtensor.advanced_set_subtensor
arange = pt.basic.ARange("int64")
x_shape = (Shape_i(0), "x")
rewrite = PatternNodeRewriter(
    (pt.linalg.det, (set_subtensor, (pt.alloc, 0, x_shape, x_shape), "x", (arange, 0, x_shape, 1), (arange, 0, x_shape, 1))),
    (pt.prod, "x"),
    name="determinant_of_diagonal",
    allow_multiple_clients=True,
)
register_canonicalize(rewrite)

x = pt.vector("x")
out = pt.linalg.det(pt.diag(x))
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
assert equal_computations([new_out], [pt.prod(x)])

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 7, 2024

Unrelated, and not needed for this PR, but would be nice for pattern rewrites to be able to enforce the type of x, specially ndim

@ricardoV94
Copy link
Member

It looks great @tanish1729, only question left is whether you are handling correctly eye that put 1 off the diagonal (k != 0)?

…nsor into det-diag-rewrite

added a suggestion using the github suggestion thingy
@tanish1729
Copy link
Contributor Author

@ricardoV94 i wasnt sure about how to get the original type of det val so i did what made most sense logically. let me know if there's a better more correct way for this

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 26, 2024

@ricardoV94 i wasnt sure about how to get the original type of det val so i did what made most sense logically. let me know if there's a better more correct way for this

node.outputs[0].type.dtype. The node is the one of the Det Op passed to your rewrite function

@ricardoV94 ricardoV94 added graph rewriting linalg Linear algebra enhancement New feature or request labels Jun 26, 2024
@ricardoV94 ricardoV94 changed the title Added det-diag rewrite Rewrite determinant of diagonal matrix Jun 26, 2024
@ricardoV94 ricardoV94 changed the title Rewrite determinant of diagonal matrix Rewrite determinant of diagonal matrix as product of diagonal Jun 26, 2024
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small typo in the comment, looks great otherwise

@jessegrabowski jessegrabowski merged commit 94e9ef0 into pymc-devs:main Jul 3, 2024
56 of 57 checks passed
@tanish1729 tanish1729 deleted the det-diag-rewrite branch July 3, 2024 13:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting linalg Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants