|
22 | 22 | inv,
|
23 | 23 | kron,
|
24 | 24 | pinv,
|
| 25 | + svd, |
25 | 26 | )
|
26 | 27 | from pytensor.tensor.rewriting.basic import (
|
27 | 28 | register_canonicalize,
|
@@ -393,14 +394,35 @@ def local_svd_uv_simplify(fgraph, node):
|
393 | 394 | and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
|
394 | 395 | """
|
395 | 396 | (x,) = node.inputs
|
396 |
| - compute_uv = False |
397 | 397 |
|
398 |
| - for cl, _ in fgraph.clients[x]: |
399 |
| - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): |
400 |
| - if (not compute_uv) and cl.op.core_op.compute_uv: |
401 |
| - compute_uv = True |
402 |
| - break |
403 |
| - |
404 |
| - if compute_uv and not node.op.compute_uv: |
405 |
| - full_matrices = node.op.full_matrices |
406 |
| - return [SVD(full_matrices=full_matrices, compute_uv=compute_uv)] |
| 398 | + if node.compute_uv: |
| 399 | + # compute_uv=True returns [u, s, v]. |
| 400 | + # if at least u or v is used, no need to rewrite this node. |
| 401 | + if ( |
| 402 | + fgraph.clients[node.outputs[0]] is not None |
| 403 | + or fgraph.clients[node.outputs[2]] is not None |
| 404 | + ): |
| 405 | + return |
| 406 | + |
| 407 | + # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. |
| 408 | + # First, iterate to see if there is an SVD Op that can be reused. |
| 409 | + for cl, _ in fgraph.clients[x]: |
| 410 | + if cl == "output": |
| 411 | + continue |
| 412 | + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): |
| 413 | + if not cl.op.core_op.compute_uv: |
| 414 | + return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} |
| 415 | + |
| 416 | + # If no SVD reusable, return a new one. |
| 417 | + return [svd(x, full_matrices=node.full_matrices, compute_uv=False)] |
| 418 | + |
| 419 | + else: |
| 420 | + # compute_uv=False returns [s]. |
| 421 | + # We want rewrite if there is another one with compute_uv=True. |
| 422 | + # For this case, just reuse the `s` from the one with compute_uv=True. |
| 423 | + for cl, _ in fgraph.clients[x]: |
| 424 | + if cl == "output": |
| 425 | + continue |
| 426 | + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): |
| 427 | + if cl.op.core_op.compute_uv: |
| 428 | + return [cl.outputs[1]] |
0 commit comments