diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index cdb1e59101..30d9084449 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -4,13 +4,17 @@
 from pytensor import Variable
 from pytensor.graph import Apply, FunctionGraph
-from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    copy_stack_trace,
+    node_rewriter,
+)
 from pytensor.tensor.basic import TensorVariable, diagonal
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
 from pytensor.tensor.nlinalg import (
+    SVD,
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
@@ -18,6 +22,7 @@
     inv,
     kron,
     pinv,
+    svd,
 )
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
@@ -377,3 +382,59 @@ def local_lift_through_linalg(
         return [block_diag(*inner_matrices)]
     else:
         raise NotImplementedError  # pragma: no cover
+
+
+@register_canonicalize
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def svd_uv_merge(fgraph, node):
+    """If we have more than one `SVD` `Op` on the same input and at least one has
+    `compute_uv=True`, re-use that decomposition's outputs instead of recomputing;
+    an `SVD` whose `u` and `v` are unused is downgraded to `compute_uv=False`.
+    """
+    if not isinstance(node.op.core_op, SVD):
+        return
+
+    (x,) = node.inputs
+
+    if node.op.core_op.compute_uv:
+        # compute_uv=True returns [u, s, v].
+        # If either u or v is used, there is no need to rewrite this node.
+        if (
+            len(fgraph.clients[node.outputs[0]]) > 0
+            or len(fgraph.clients[node.outputs[2]]) > 0
+        ):
+            return
+
+        # Otherwise, replace the s of this node with the s of an SVD Op that has compute_uv=False.
+        # First, check whether an existing SVD Op can be reused.
+        for cl, _ in fgraph.clients[x]:
+            if cl == "output":
+                continue
+            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
+                if not cl.op.core_op.compute_uv:
+                    return {
+                        node.outputs[1]: cl.outputs[0],
+                    }
+
+        # If no SVD can be reused, introduce a new one.
+        return {
+            node.outputs[1]: svd(
+                x, full_matrices=node.op.core_op.full_matrices, compute_uv=False
+            ),
+        }
+
+    else:
+        # compute_uv=False returns [s].
+        # We want to rewrite if there is another SVD Op with compute_uv=True whose u or v is used.
+        # In that case, just reuse the `s` from the one with compute_uv=True.
+        for cl, _ in fgraph.clients[x]:
+            if cl == "output":
+                continue
+            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
+                if cl.op.core_op.compute_uv and (
+                    len(fgraph.clients[cl.outputs[0]]) > 0
+                    or len(fgraph.clients[cl.outputs[2]]) > 0
+                ):
+                    return [cl.outputs[1]]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 1e9d6194db..523742e356 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -15,11 +15,13 @@
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import _allclose, dot, matmul
 from pytensor.tensor.nlinalg import (
+    SVD,
     Det,
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
     matrix_inverse,
+    svd,
 )
 from pytensor.tensor.rewriting.linalg import inv_as_solve
 from pytensor.tensor.slinalg import (
@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
         test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
 
     np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
+
+
+def test_svd_uv_merge():
+    a = matrix("a")
+    s_1 = svd(a, full_matrices=False, compute_uv=False)
+    _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
+    _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
+    u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True)
+    # `grad` will introduce an SVD Op with compute_uv=True
+    # (full_matrices=True is not supported for the grad of svd)
+    gs = pt.grad(pt.sum(s_1), a)
+
+    # 1. compute_uv=False needs rewriting with compute_uv=True
+    f_1 = pytensor.function([a], gs)
+    nodes = f_1.maker.fgraph.apply_nodes
+    svd_counter = 0
+    for node in nodes:
+        if isinstance(node.op, SVD):
+            assert node.op.compute_uv
+            svd_counter += 1
+    assert svd_counter == 1
+
+    # 2. compute_uv=True needs rewriting with compute_uv=False, reuse node
+    f_2 = pytensor.function([a], [s_1, s_2])
+    nodes = f_2.maker.fgraph.apply_nodes
+    svd_counter = 0
+    for node in nodes:
+        if isinstance(node.op, SVD):
+            assert not node.op.compute_uv
+            svd_counter += 1
+    assert svd_counter == 1
+
+    # 3. compute_uv=True needs rewriting with compute_uv=False, create new node
+    # full_matrices needs to retain its value
+    f_3 = pytensor.function([a], [s_2])
+    nodes = f_3.maker.fgraph.apply_nodes
+    svd_counter = 0
+    for node in nodes:
+        if isinstance(node.op, SVD):
+            assert not node.op.compute_uv
+            svd_counter += 1
+    assert svd_counter == 1
+
+    # Same as case 3., but with a different full_matrices
+    f_4 = pytensor.function([a], [s_3])
+    nodes = f_4.maker.fgraph.apply_nodes
+    svd_counter = 0
+    for node in nodes:
+        if isinstance(node.op, SVD):
+            assert not node.op.compute_uv
+            assert node.op.full_matrices
+            svd_counter += 1
+    assert svd_counter == 1
+
+    # 4. No rewrite should happen
+    f_5 = pytensor.function([a], [u_4])
+    nodes = f_5.maker.fgraph.apply_nodes
+    svd_counter = 0
+    for node in nodes:
+        if isinstance(node.op, SVD):
+            assert node.op.full_matrices
+            assert node.op.compute_uv
+            svd_counter += 1
+    assert svd_counter == 1
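
Note (not part of the patch): a minimal user-level sketch of what `svd_uv_merge` is meant to buy, mirroring case 2 of the new test. The variable names below are illustrative, and it assumes the rewrite is registered in the default rewrite database, as the `register_*` decorators above suggest: a graph that requests both the bare singular values and the full decomposition of the same matrix should compile down to a single SVD node.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import SVD, svd

    a = pt.matrix("a")
    s_only = svd(a, full_matrices=False, compute_uv=False)  # only needs s
    u, s, v = svd(a, full_matrices=False, compute_uv=True)  # needs u, s, v

    # Both SVDs act on the same input, so after svd_uv_merge the s_only output
    # should be taken from the compute_uv=True node and only one SVD node remains.
    f = pytensor.function([a], [s_only, u])
    svd_nodes = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, SVD)]
    print(len(svd_nodes))  # expected: 1

    x = np.random.default_rng(0).normal(size=(3, 3))
    print(f(x))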