-
Notifications
You must be signed in to change notification settings - Fork 132
Added inv(diag) -> (1/diag) rewrite for eye mul #860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@register_canonicalize | ||
@register_stabilize | ||
@node_rewriter([Blockwise]) | ||
def rewrite_inv_for_diag_and_orthonormal(fgraph, node): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imo this should be split into two rewrites
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay i'll do that
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], | ||
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], | ||
) | ||
def test_inv_diag_from_eye_mul(shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test for orthonormal matrix?
isinstance(core_op, inv) | ||
or isinstance(core_op, pinv) | ||
or isinstance(core_op, solve) | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are other solve Ops that could represent an inversion, notable solve_triangular
return None | ||
|
||
# Dealing with direct inverse Ops | ||
if isinstance(core_op, inv) or isinstance(core_op, pinv): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if isinstance(core_op, inv) or isinstance(core_op, pinv): | |
if isinstance(core_op, inv | pinv): |
return [inverse] | ||
|
||
# Dealing with solve(x, eye) | ||
elif isinstance(core_op, solve): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of the solve logic is going to be the same as the inv/pinv, except that you just have to check that B is eye, then extract the A argument. It probably won't be best to split it out this way
orthonormal_input = input_to_inv | ||
inverse = orthonormal_input.T | ||
return [inverse] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
orthonormal_input = input_to_inv | |
inverse = orthonormal_input.T | |
return [inverse] | |
return [input_to_inv.T] |
): | ||
return None | ||
|
||
# To make sure input is orthonormal, we have to check that its not S |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# To make sure input is orthonormal, we have to check that its not S | |
# To make sure input is orthonormal, we have to check that its not S (output order is U, S, Vh, so S is index 1) |
if len(non_eye_inputs) != 1: | ||
return None | ||
|
||
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] | |
eye, non_eye = eye_input[0], non_eye_inputs[0] |
Note sure what makes them "useful"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just the fact that im using them 😛 i'll update the names
@@ -377,3 +379,119 @@ def local_lift_through_linalg( | |||
return [block_diag(*inner_matrices)] | |||
else: | |||
raise NotImplementedError # pragma: no cover | |||
|
|||
|
|||
def _find_diag_from_eye_mul(potential_mul_input): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this duplicated from the other PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea. that one isnt merged yet so i had to add this function in this one too
Looks like you also need to rebase from main |
…nsor into inv-diag-rewrite conflicts fixed
@tanish1729, you should be able to rebase using the command |
hi! yep i figured that out and was doing that but a lot of things (plenty of conflicts) got confusing so i decided to start over. |
Description
Adds the rewrite for inv(diag) -> 1/diag
Related Issue
Checklist
Type of change