Add rewrite to merge multiple SVD Ops with different settings #769

Merged 18 commits on Jun 28, 2024. (Changes shown from 10 commits.)
2 changes: 1 addition & 1 deletion .gitignore
@@ -55,4 +55,4 @@ pytensor-venv/
.vscode/
testing-report.html
coverage.xml
-.coverage.*
+.coverage.*
51 changes: 50 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
@@ -4,20 +4,25 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
    copy_stack_trace,
    node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
    SVD,
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
    det,
    inv,
    kron,
    pinv,
    svd,
)
from pytensor.tensor.rewriting.basic import (
    register_canonicalize,
@@ -377,3 +382,47 @@
        return [block_diag(*inner_matrices)]
    else:
        raise NotImplementedError  # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """If the graph contains more than one `SVD` `Op` on the same input and at
    least one of them has `compute_uv=True`, rewrite so that the decomposition
    is computed only once and its outputs are reused instead of recomputed.
    """
    (x,) = node.inputs

    if node.op.core_op.compute_uv:
        # compute_uv=True returns [u, s, v].
        # If u or v is used downstream, there is nothing to rewrite on this node.
        # Note: fgraph.clients[...] is a list, so check truthiness, not `is not None`.
        if fgraph.clients[node.outputs[0]] or fgraph.clients[node.outputs[2]]:
            return

        # Otherwise, replace the s of this node with the s of an SVD Op that has
        # compute_uv=False. First, iterate to see if such an Op can be reused.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    # Map the old variable to its replacement.
                    return {node.outputs[1]: cl.outputs[0]}

        # If no SVD is reusable, return a new one.
        return [svd(x, full_matrices=node.op.core_op.full_matrices, compute_uv=False)]

    else:
        # compute_uv=False returns [s].
        # We want to rewrite if there is another SVD Op on the same input with
        # compute_uv=True; in that case, just reuse its `s` output.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv:
Comment (Member):
We only want to do this if that other node is actually using the u/v outputs. If not, we would instead want to replace that node with this one.

Comment (Contributor Author):
That would be taken care of by the first half when it is that node's turn. Since this is a local rewrite applied to every SVD node, each node gets its turn.

Comment (Member):
Even if you don't want to handle that other node, there's no reason to rewrite this node into it. In general it's better to do as few rewrites as possible, since every time a rewrite succeeds all other candidate rewrites are rerun (until an equilibrium is reached and nothing changes anymore).
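The rerun-until-equilibrium behavior described in this comment can be sketched in a few lines of plain Python. This is a toy model, not PyTensor's actual rewriter machinery: the string "graph" and the merge rule are invented purely for illustration.

```python
# Toy model of equilibrium rewriting: after any rewrite succeeds, all
# candidate rewrites are retried until a full pass changes nothing.
# (Illustrative only -- PyTensor's equilibrium rewriter operates on graphs.)
def run_to_equilibrium(rewrites, graph):
    changed = True
    while changed:
        changed = False
        for rewrite in rewrites:
            result = rewrite(graph)
            if result is not None and result != graph:
                graph = result
                changed = True  # a rewrite succeeded: rerun every candidate
    return graph

# A single toy rewrite that merges adjacent duplicates, loosely analogous
# to merging redundant SVD Ops.
rules = [lambda g: g.replace("svd svd", "svd") if "svd svd" in g else None]
print(run_to_equilibrium(rules, "svd svd svd svd"))  # -> "svd"
```

This is why fewer, more decisive rewrites are cheaper: each success triggers another full pass over all candidates.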

Comment (Member):
On second thought, I like your eager approach better; it's more readable. Since SVDs are rare, we don't need to over-optimize.

                    return [cl.outputs[1]]

Comment (Contributor Author):
Thanks @ricardoV94. My understanding is this: an SVD with compute_uv=False returns [s], while one with compute_uv=True returns [u, s, v]. We want to rewrite when two SVD Ops in the graph share the same input but differ in compute_uv. Take the concrete example of two SVD Ops: svd_f, which returns [s_f], and svd_t, which returns [u_t, s_t, v_t]. Depending on whether u_t or v_t is used (both are computed even if only one of them feeds later calculations), one of the following rewrites can happen:

  • Case 1: if at least one of u_t or v_t is used, return [s_t] in place of [s_f].
  • Case 2: otherwise, return [s_f] in place of [s_t].
  • Case 3: additionally, if there is only one SVD Op, it has compute_uv=True, and neither u nor v is used, replace it with a new SVD Op with compute_uv=False.
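The case analysis above can be condensed into a small decision function. This is a hypothetical sketch mirroring the comment, not code from the PR; all names are invented for illustration:

```python
# Hypothetical sketch of the three cases described above.
# this_computes_uv: does the current SVD Op have compute_uv=True?
# uv_used: is u_t or v_t consumed downstream?
# sibling_without_uv / sibling_with_uv: does a matching SVD Op on the
# same input exist with the opposite compute_uv setting?
def choose_rewrite(this_computes_uv, uv_used, sibling_without_uv, sibling_with_uv):
    if this_computes_uv:
        if uv_used:
            return "keep"                       # u/v needed: nothing to do
        if sibling_without_uv:
            return "replace s_t with s_f"       # case 2
        return "new SVD with compute_uv=False"  # case 3
    if sibling_with_uv:
        return "replace s_f with s_t"           # case 1 (eager: always reuse)
    return "keep"
```

Because the rewrite is local, each SVD node is run through this decision independently, which is why the two halves of the function body cover complementary cases.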

Comment (Member, @ricardoV94, May 22, 2024):
Yup, that's it! When you write down the updated rewrite, feel free to add comments with as much explanation as you gave here.

Comment (Member, @ricardoV94, May 22, 2024):

There could also be some weird cases with three SVDs: one with compute_uv=True and full_matrices=True that doesn't actually use u/v, and one with compute_uv=True and full_matrices=False that does (or vice versa). In that case we could replace one with the other, but perhaps that's too much to worry about and unlikely to happen. I don't see ignoring it causing any bug; I'm just raising attention so we don't accidentally rewrite a full-matrices SVD into a non-full-matrices one whose outputs are actually used.

Comment (Contributor Author):
For this one: is return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} the correct syntax?

Comment (Member):
Yup, that tells it to replace the key variable with the value variable.
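The dict return convention can be illustrated with a toy substitution over a list of variable names. This is a simplification: in PyTensor, the FunctionGraph replaces every use of the key variable throughout the graph, and the key must be the variable itself (not its client list).

```python
# Toy illustration of the {old_variable: new_variable} return convention:
# every use of the key is substituted with the value.
def apply_replacements(variables, replacements):
    return [replacements.get(v, v) for v in variables]

graph_uses = ["x", "s_f", "y", "s_f"]
print(apply_replacements(graph_uses, {"s_f": "s_t"}))  # -> ['x', 's_t', 'y', 's_t']
```

Returning a list instead of a dict (as the other branches of the rewrite do) maps the node's own outputs positionally to the returned variables.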
