Add rewrite to merge multiple SVD Ops with different settings #769


Merged · 18 commits · Jun 28, 2024

Changes from 5 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -56,3 +56,4 @@ pytensor-venv/
testing-report.html
coverage.xml
.coverage.*
pics
34 changes: 33 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
@@ -4,13 +4,17 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
    copy_stack_trace,
    node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
    SVD,
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
@@ -377,3 +381,31 @@ def local_lift_through_linalg(
        return [block_diag(*inner_matrices)]
    else:
        raise NotImplementedError  # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """If there is more than one `SVD` `Op` and at least one of them has the
    keyword argument `compute_uv=True`, change `compute_uv=False` to `True`
    everywhere and let `pytensor` re-use the decomposition outputs instead of
    recomputing them.
    """
    (x,) = node.inputs
    svd_count = 0
    compute_uv = False
    not_compute_uv_svd_list = []

    for cl, _ in fgraph.clients[x]:
Member:

You have to be careful: if the output of the SVD is an output of the function, one of the clients will be the string "output" and the call cl.op will fail.

        if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
            svd_count += 1
            if (not compute_uv) and cl.op.core_op.compute_uv:
Member:
I don't think you need that first check?

Suggested change:

-            if (not compute_uv) and cl.op.core_op.compute_uv:
+            if cl.op.core_op.compute_uv:

Member:
You should check if the uv outputs of this node are actually used (i.e., they have clients of their own). If not, they are useless and the rewrite shouldn't happen. In fact, this or another rewrite should change the flag from True to False for those nodes

                compute_uv = True
            if not cl.op.core_op.compute_uv:
                not_compute_uv_svd_list.append(cl)

    if svd_count > 1 and compute_uv:
        for cl in not_compute_uv_svd_list:
            cl.op.core_op.compute_uv = True
        return [cl.outputs[0] for cl in not_compute_uv_svd_list]
Member:
I think changing properties of the op in place might lead to problems...

This rewrite function should run for each SVD node, so maybe it is easier to just locate an existing compute_uv = True node and return that as a replacement for each compute_uv = False node?

So something like:

  • If compute_uv is True, return and do nothing.
  • Check whether there is a compute_uv = True node in the graph with the same input. If not, return and do nothing.
  • Return the existing output of that node as a replacement for the current compute_uv = False node.

I wonder, though, whether there could be bad interactions somewhere if there is a rewrite that replaces compute_uv = False nodes when they are not used? We don't want to run into any infinite cycles...
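Those steps can be sketched roughly as below, with hypothetical stand-in classes in place of the real pytensor `Op`/`Apply` types (the real version would live inside a `@node_rewriter([SVD])` function like the one in this diff):

```python
class FakeSVD:
    """Hypothetical stand-in for the SVD core op."""
    def __init__(self, compute_uv):
        self.compute_uv = compute_uv

class FakeApply:
    """Hypothetical stand-in for an Apply node."""
    def __init__(self, op, inputs, outputs):
        self.op, self.inputs, self.outputs = op, inputs, outputs

def merge_svd_sketch(node, clients):
    """Replace a compute_uv=False SVD with the S output of an existing
    compute_uv=True SVD of the same input, following the steps above."""
    if node.op.compute_uv:
        return None                       # step 1: nothing to replace
    (x,) = node.inputs
    for cl, _idx in clients[x]:
        if cl == "output" or cl is node:  # skip sentinel and the node itself
            continue
        if isinstance(cl.op, FakeSVD) and cl.op.compute_uv:
            return [cl.outputs[1]]        # step 3: reuse its S output
    return None                           # step 2: no candidate found

# Usage: two SVDs of the same input, one with and one without U/V.
x = "x"
full = FakeApply(FakeSVD(True), [x], ["U", "S", "V"])
small = FakeApply(FakeSVD(False), [x], ["S2"])
clients = {x: [(full, 0), (small, 0)]}
assert merge_svd_sketch(small, clients) == ["S"]   # replaced by full's S
assert merge_svd_sketch(full, clients) is None     # True node left alone
```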

Member:
@ricardoV94 Do you know if there are any problems that could happen if a rewrite returns an existing variable instead of a new one?

Member:
I think there will be a problem only when a rewrite tries to replace a variable by another that depends on the original variable.

Member:
And yes, we shouldn't modify the properties in place. We should replace the smaller Op with the bigger one; just make sure the smaller one is not among the ancestors of the bigger one.

Member:
Otherwise, creating a new SVD should be simple: just call the user-facing constructor with the specific flags.

Contributor Author @HangenYuu (May 16, 2024):
Sorry, I seem to have dumped information carelessly. The gist was:

  1. I updated the code logic to be a node rewriter.
  2. The rewrite is registered properly in optdb. However, I am having trouble coming up with a test case that shows the effect of the rewrite. Perhaps @jessegrabowski can provide the original use case that led to opening issue #732 (Add rewrite to merge multiple SVD Ops with different settings)?

Member @jessegrabowski (May 19, 2024):

It will arise in gradient graphs. For example, you can just do:

X = pt.dmatrix('X')
s = pt.linalg.svd(X, compute_uv=False)
g = pt.grad(s.sum(), X)

The graph for g will re-compute the SVD of X during the backward pass with compute_uv=True, because the matrices U and V are required to compute the gradient of s with respect to X. PyTensor then won't be able to see that these two computations are the same, and will end up computing the SVD twice.

Contributor Author @HangenYuu (May 19, 2024):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import matrix
from pytensor.tensor.nlinalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gs = pt.grad(pt.sum(s), a_pt)
f = pytensor.function([a_pt], gs)
e = pytensor.graph.fg.FunctionGraph([a_pt], [gs], clone=False)

Thank you. I indeed received a graph for gs and e with 2 different SVD nodes:

[graph screenshot: two SVD nodes]

But for f, I receive a graph with just a single SVD (which seems to have been rewritten already with compute_uv=True):

[graph screenshot: one SVD node with compute_uv=True]

The rewritten graph of f will be used in the calculation if I run f([[1, 2], [3, 4]]). Does this already satisfy your end goal?

Contributor Author @HangenYuu (May 19, 2024):

This is f's summary profile:

Function profiling
==================
  Message: /tmp/ipykernel_1282122/871230895.py:10
  Time in 1 calls to Function.__call__: 3.448710e-02s
  Time in Function.vm.__call__: 0.03426380921155214s (99.353%)
  Time in thunks: 0.03424406051635742s (99.295%)
  Total compilation time: 4.109558e-02s
    Number of Apply nodes: 2
    PyTensor rewrite time: 2.893809e-02s
       PyTensor validate time: 2.457825e-04s
    PyTensor Linker time (includes C, CUDA code generation/compiling): 0.00876139895990491s
       C-cache preloading 5.506449e-03s
       Import time 8.061258e-04s
       Node make_thunk time 1.967770e-03s
           Node Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2) time 1.942240e-03s
           Node SVD{full_matrices=False, compute_uv=True}(a) time 1.436425e-05s

Time in all call to pytensor.grad() 1.036228e-02s
Time since pytensor import 2.774s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1       1   pytensor.tensor.nlinalg.SVD
   0.2%   100.0%       0.000s       6.60e-05s     C        1       1   pytensor.tensor.blas.Dot22
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1        1   SVD{full_matrices=False, compute_uv=True}
   0.2%   100.0%       0.000s       6.60e-05s     C        1        1   Dot22
   ... (remaining 0 Ops account for   0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  99.8%    99.8%       0.034s       3.42e-02s      1     0   SVD{full_matrices=False, compute_uv=True}(a)
   0.2%   100.0%       0.000s       6.60e-05s      1     1   Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2)
   ... (remaining 0 Apply instances account for 0.00%(0.00s) of the runtime)

Member:
pytensor.dprint may be an easier way to introspect the graphs.
