Skip to content

Commit 55ad931

Browse files
committed
Refactored logic for SVD to support all 3 cases
1 parent 14f89f8 commit 55ad931

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
inv,
2323
kron,
2424
pinv,
25+
svd,
2526
)
2627
from pytensor.tensor.rewriting.basic import (
2728
register_canonicalize,
@@ -393,14 +394,35 @@ def local_svd_uv_simplify(fgraph, node):
393394
and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
394395
"""
395396
(x,) = node.inputs
396-
compute_uv = False
397397

398-
for cl, _ in fgraph.clients[x]:
399-
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
400-
if (not compute_uv) and cl.op.core_op.compute_uv:
401-
compute_uv = True
402-
break
403-
404-
if compute_uv and not node.op.compute_uv:
405-
full_matrices = node.op.full_matrices
406-
return [SVD(full_matrices=full_matrices, compute_uv=compute_uv)]
398+
if node.compute_uv:
399+
# compute_uv=True returns [u, s, v].
400+
# if at least u or v is used, no need to rewrite this node.
401+
if (
402+
fgraph.clients[node.outputs[0]] is not None
403+
or fgraph.clients[node.outputs[2]] is not None
404+
):
405+
return
406+
407+
# Else, has to replace the s of this node with s of an SVD Op that compute_uv=False.
408+
# First, iterate to see if there is an SVD Op that can be reused.
409+
for cl, _ in fgraph.clients[x]:
410+
if cl == "output":
411+
continue
412+
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
413+
if not cl.op.core_op.compute_uv:
414+
return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}
415+
416+
# If no SVD reusable, return a new one.
417+
return [svd(x, full_matrices=node.full_matrices, compute_uv=False)]
418+
419+
else:
420+
# compute_uv=False returns [s].
421+
# We want rewrite if there is another one with compute_uv=True.
422+
# For this case, just reuse the `s` from the one with compute_uv=True.
423+
for cl, _ in fgraph.clients[x]:
424+
if cl == "output":
425+
continue
426+
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
427+
if cl.op.core_op.compute_uv:
428+
return [cl.outputs[1]]

0 commit comments

Comments
 (0)