Skip to content

Added inv(diag) -> (1/diag) rewrite for eye mul #860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

tanish1729
Copy link
Contributor

@tanish1729 tanish1729 commented Jun 27, 2024

Description

Adds the rewrite for inv(diag) -> 1/diag

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@tanish1729
Copy link
Contributor Author

@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_for_diag_and_orthonormal(fgraph, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo this should be split into two rewrites

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay i'll do that

[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_inv_diag_from_eye_mul(shape):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test for orthonormal matrix?

isinstance(core_op, inv)
or isinstance(core_op, pinv)
or isinstance(core_op, solve)
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other solve Ops that could represent an inversion, notable solve_triangular

return None

# Dealing with direct inverse Ops
if isinstance(core_op, inv) or isinstance(core_op, pinv):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(core_op, inv) or isinstance(core_op, pinv):
if isinstance(core_op, inv | pinv):

return [inverse]

# Dealing with solve(x, eye)
elif isinstance(core_op, solve):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the solve logic is going to be the same as the inv/pinv, except that you just have to check that B is eye, then extract the A argument. It probably won't be best to split it out this way

Comment on lines +491 to +493
orthonormal_input = input_to_inv
inverse = orthonormal_input.T
return [inverse]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
orthonormal_input = input_to_inv
inverse = orthonormal_input.T
return [inverse]
return [input_to_inv.T]

):
return None

# To make sure input is orthonormal, we have to check that its not S
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# To make sure input is orthonormal, we have to check that its not S
# To make sure input is orthonormal, we have to check that its not S (output order is U, S, Vh, so S is index 1)

if len(non_eye_inputs) != 1:
return None

useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
eye, non_eye = eye_input[0], non_eye_inputs[0]

Note sure what makes them "useful"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just the fact that im using them 😛 i'll update the names

@@ -377,3 +379,119 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


def _find_diag_from_eye_mul(potential_mul_input):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this duplicated from the other PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea. that one isnt merged yet so i had to add this function in this one too

@jessegrabowski
Copy link
Member

Looks like you also need to rebase from main

@tanish1729 tanish1729 closed this Jul 3, 2024
@tanish1729 tanish1729 deleted the inv-diag-rewrite branch July 3, 2024 13:08
@maresb
Copy link
Contributor

maresb commented Jul 3, 2024

@tanish1729, you should be able to rebase using the command git rebase main to avoid deleting your branch. This way everything can stay within the same PR.

@tanish1729
Copy link
Contributor Author

hi! yep i figured that out and was doing that but a lot of things (plenty of conflicts) got confusing so i decided to start over.
i did eventually figure out rebasing and stuff by trying it out so i'll be sure to use that from now. thanks :D

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants