From 567b8d31e29c7e32511ceba3b55f2e9c2673a8a5 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Fri, 2 Feb 2024 21:19:06 +0100
Subject: [PATCH 1/2] Add rewrite to lift linear algebra through certain
 linalg ops

---
 pytensor/compile/builders.py          |  2 +-
 pytensor/tensor/nlinalg.py            | 12 ++++-
 pytensor/tensor/rewriting/linalg.py   | 74 ++++++++++++++++++++++++++-
 tests/tensor/rewriting/test_linalg.py | 59 ++++++++++++++++++++-
 tests/tensor/test_nlinalg.py          |  8 +++
 5 files changed, 150 insertions(+), 5 deletions(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 594a1188de..39d42aa133 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -7,7 +7,7 @@
 from typing import cast
 
 import pytensor.tensor as pt
-from pytensor import function
+from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
 from pytensor.compile.mode import optdb
 from pytensor.compile.sharedvalue import SharedVariable
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 7b8768a154..9d3f615dc8 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -7,6 +7,7 @@
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 from pytensor import scalar as ps
+from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
@@ -614,7 +615,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
 
     Returns
     -------
-    U, V, D : matrices
+    U, V, D : matrices 
 
     """
     return Blockwise(SVD(full_matrices, compute_uv))(a)
@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
     return TensorSolve(axes)(a, b)
 
 
+class KroneckerProduct(OpFromGraph):
+    """
+    Wrapper Op for Kronecker graphs
+    """
+
+
 def kron(a, b):
     """Kronecker product.
 
@@ -1042,7 +1049,8 @@ def kron(a, b):
     out_shape = tuple(a.shape * b.shape)
     output_out_of_shape = a_reshaped * b_reshaped
     output_reshaped = output_out_of_shape.reshape(out_shape)
-    return output_reshaped
+
+    return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)
 
 
 __all__ = [
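
A quick illustrative sketch (not part of the diff; assumes patch 1 is applied) of what the `KroneckerProduct` wrapper buys: because `kron` now returns its reshape-based graph wrapped in a single `OpFromGraph` subclass, rewrites can match the whole Kronecker computation as one identifiable node instead of a reshape/multiply subgraph:

    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import KroneckerProduct

    x = pt.dmatrix("x")
    y = pt.dmatrix("y")

    # kron builds the usual reshape-based graph, then wraps it, so the
    # output's owner is a single KroneckerProduct node that the new rewrite
    # in pytensor/tensor/rewriting/linalg.py can dispatch on directly.
    z = pt.linalg.kron(x, y)
    assert isinstance(z.owner.op, KroneckerProduct)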
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index ea83d9356a..cdb1e59101 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -1,22 +1,35 @@
 import logging
+from collections.abc import Callable
 from typing import cast
 
+from pytensor import Variable
+from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import TensorVariable, diagonal
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
-from pytensor.tensor.nlinalg import MatrixInverse, det
+from pytensor.tensor.nlinalg import (
+    KroneckerProduct,
+    MatrixInverse,
+    MatrixPinv,
+    det,
+    inv,
+    kron,
+    pinv,
+)
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
 from pytensor.tensor.slinalg import (
+    BlockDiagonal,
     Cholesky,
     Solve,
     SolveBase,
+    block_diag,
     cholesky,
     solve,
     solve_triangular,
@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):
 
     # TODO: have a reduction like prod and sum that simply
     # returns the sign of the prod multiplication.
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def local_lift_through_linalg(
+    fgraph: FunctionGraph, node: Apply
+) -> list[Variable] | None:
+    """
+    Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
+    that join matrices (KroneckerProduct, BlockDiagonal).
+
+    This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
+    operations on component matrices instead of one large one. For example, when taking the inverse of a Kronecker
+    product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
+    reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3), where n and m are the dimensions of the component
+    matrices.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+
+    # TODO: Simplify this if we end up Blockwising KroneckerProduct
+    if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
+        y = node.inputs[0]
+        outer_op = node.op
+
+        if y.owner and (
+            isinstance(y.owner.op, Blockwise)
+            and isinstance(y.owner.op.core_op, BlockDiagonal)
+            or isinstance(y.owner.op, KroneckerProduct)
+        ):
+            input_matrices = y.owner.inputs
+
+            if isinstance(outer_op.core_op, MatrixInverse):
+                outer_f = cast(Callable, inv)
+            elif isinstance(outer_op.core_op, Cholesky):
+                outer_f = cast(Callable, cholesky)
+            elif isinstance(outer_op.core_op, MatrixPinv):
+                outer_f = cast(Callable, pinv)
+            else:
+                raise NotImplementedError  # pragma: no cover
+
+            inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
+
+            if isinstance(y.owner.op, KroneckerProduct):
+                return [kron(*inner_matrices)]
+            elif isinstance(y.owner.op.core_op, BlockDiagonal):
+                return [block_diag(*inner_matrices)]
+            else:
+                raise NotImplementedError  # pragma: no cover
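
For reference, the two identities the rewrite relies on can be checked numerically. A minimal NumPy sketch (not part of the diff; shapes and conditioning shifts chosen arbitrarily):

    import numpy as np
    from scipy.linalg import block_diag

    rng = np.random.default_rng(0)
    A = rng.normal(size=(3, 3)) + 3 * np.eye(3)  # shifted to keep it well-conditioned
    B = rng.normal(size=(4, 4)) + 4 * np.eye(4)

    # inv(A kron B) == inv(A) kron inv(B)
    np.testing.assert_allclose(
        np.linalg.inv(np.kron(A, B)),
        np.kron(np.linalg.inv(A), np.linalg.inv(B)),
        atol=1e-12,
    )

    # inv(block_diag(A, B)) == block_diag(inv(A), inv(B))
    np.testing.assert_allclose(
        np.linalg.inv(block_diag(A, B)),
        block_diag(np.linalg.inv(A), np.linalg.inv(B)),
        atol=1e-12,
    )

The same commutation holds for the Cholesky factor (of symmetric positive definite blocks) and for the pseudo-inverse, which is why all three core Ops are handled by the one rewrite above.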
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 54ee110f6d..1e9d6194db 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -14,9 +14,16 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import _allclose, dot, matmul
-from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
+from pytensor.tensor.nlinalg import (
+    Det,
+    KroneckerProduct,
+    MatrixInverse,
+    MatrixPinv,
+    matrix_inverse,
+)
 from pytensor.tensor.rewriting.linalg import inv_as_solve
 from pytensor.tensor.slinalg import (
+    BlockDiagonal,
     Cholesky,
     Solve,
     SolveBase,
@@ -333,3 +340,53 @@ def test_invalid_batched_a(self):
             ref_fn(test_a, test_b),
             rtol=1e-7 if config.floatX == "float64" else 1e-5,
         )
+
+
+@pytest.mark.parametrize(
+    "constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
+)
+@pytest.mark.parametrize(
+    "f_op, f",
+    [
+        (MatrixInverse, pt.linalg.inv),
+        (Cholesky, pt.linalg.cholesky),
+        (MatrixPinv, pt.linalg.pinv),
+    ],
+    ids=["inv", "cholesky", "pinv"],
+)
+@pytest.mark.parametrize(
+    "g_op, g",
+    [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
+    ids=["block_diag", "kron"],
+)
+def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
+    if pytensor.config.floatX.endswith("32"):
+        pytest.skip("Test is flaky at single precision")
+
+    A, B = list(map(constructor, "ab"))
+    X = f(g(A, B))
+
+    f1 = pytensor.function(
+        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
+    )
+    f2 = pytensor.function(
+        [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
+    )
+
+    all_apply_nodes = f1.maker.fgraph.apply_nodes
+    f_ops = [
+        x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), f_op)
+    ]
+    g_ops = [
+        x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), g_op)
+    ]
+
+    assert len(f_ops) == 2
+    assert len(g_ops) == 1
+
+    test_vals = [
+        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
+    ]
+    test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
+
+    np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 3b5de5fcf2..4bb88c9bc4 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -590,6 +590,14 @@ def setup_method(self):
         self.op = kron
         super().setup_method()
 
+    def test_vec_vec_kron_raises(self):
+        x = vector()
+        y = vector()
+        with pytest.raises(
+            TypeError, match="kron: inputs dimensions must sum to 3 or more"
+        ):
+            kron(x, y)
+
     @pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
     @pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
     def test_perform(self, shp0, shp1):

From c1da5277a70f91f6aeaf54ef9873a1a0cdedb4a5 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sat, 20 Apr 2024 23:20:17 +0200
Subject: [PATCH 2/2] Avoid duplicated inputs in KroneckerProduct OpFromGraph

---
 pytensor/compile/builders.py   |  9 +++++++++
 pytensor/tensor/nlinalg.py     |  5 +++++
 tests/compile/test_builders.py | 17 +++++++++++++++++
 tests/tensor/test_slinalg.py   |  4 ++--
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 39d42aa133..9e82df0fe4 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -400,6 +400,15 @@ def __init__(
             Check :func:`pytensor.function` for more arguments, only works when not
             inline.
         """
+        ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore"
+        if not ignore_unused_inputs and len(inputs) != len(set(inputs)):
+            var_counts = {var: inputs.count(var) for var in inputs}
+            duplicated_inputs = [var for var, count in var_counts.items() if count > 1]
+            raise ValueError(
+                f"The following variables were provided more than once as inputs to the OpFromGraph, resulting in an "
+                f"invalid graph: {duplicated_inputs}. Use dummy variables or var.copy() to distinguish "
+                f"variables when creating the OpFromGraph graph."
+            )
 
         if not (isinstance(inputs, list) and isinstance(outputs, list)):
             raise TypeError("Inputs and outputs must be lists")
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 9d3f615dc8..85215fbe06 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -1034,6 +1034,11 @@ def kron(a, b):
     """
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
+
+    if a is b:
+        # In case a is the same as b, we need a different variable to build the OFG
+        b = a.copy()
+
     if a.ndim + b.ndim <= 2:
         raise TypeError(
             "kron: inputs dimensions must sum to 3 or more. "
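
The guard above matters for expressions like `kron(x, x)`, where both arguments are the very same variable and would otherwise become duplicated inputs to the `OpFromGraph` built inside `kron`. A small sketch (assumes both patches are applied):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dmatrix("x")

    # Internally b is replaced by a.copy(), so the wrapped OpFromGraph
    # still receives two distinct input variables.
    z = pt.linalg.kron(x, x)

    fn = pytensor.function([x], z)
    xv = np.eye(2)
    np.testing.assert_allclose(fn(xv), np.kron(xv, xv))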
" diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index c0ddf7e894..6f8e8035d1 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -118,6 +118,7 @@ def test_grad_grad(self, cls_ofg): f = op(x, y, z) f = f - grad(pt_sum(f), y) f = f - grad(pt_sum(f), y) + fn = function([x, y, z], f) xv = np.ones((2, 2), dtype=config.floatX) yv = np.ones((2, 2), dtype=config.floatX) * 3 @@ -584,6 +585,22 @@ def test_explicit_input_from_shared(self): out = test_ofg(y, y) assert out.eval() == 4 + def test_repeated_inputs(self): + x = pt.dscalar("x") + y = pt.dscalar("y") + + with pytest.raises( + ValueError, + match="There following variables were provided more than once as inputs to the " + "OpFromGraph", + ): + OpFromGraph([x, x, y], [x + y]) + + # Test that repeated inputs will be allowed if unused inputs are ignored + g = OpFromGraph([x, x, y], [x + y], on_unused_input="ignore") + f = g(x, x, y) + assert f.eval({x: 5, y: 5}) == 10 + @config.change_flags(floatX="float64") def test_debugprint(): diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index d39c370ed3..e468b56e84 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -514,8 +514,8 @@ def test_expm_grad_3(): def test_solve_discrete_lyapunov_via_direct_real(): N = 5 rng = np.random.default_rng(utt.fetch_seed()) - a = pt.dmatrix() - q = pt.dmatrix() + a = pt.dmatrix("a") + q = pt.dmatrix("q") f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")]) A = rng.normal(size=(N, N))