Add rewrite to lift linear algebra through certain linalg ops #622

Merged · 2 commits · Apr 28, 2024
11 changes: 10 additions & 1 deletion pytensor/compile/builders.py
@@ -7,7 +7,7 @@
from typing import cast

import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable
@@ -400,6 +400,15 @@ def __init__(
Check :func:`pytensor.function` for more arguments, only works when not
inline.
"""
ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore"
if not ignore_unused_inputs and len(inputs) != len(set(inputs)):
var_counts = {var: inputs.count(var) for var in inputs}
duplicated_inputs = [var for var, count in var_counts.items() if count > 1]
raise ValueError(
f"There following variables were provided more than once as inputs to the OpFromGraph, resulting in an "
f"invalid graph: {duplicated_inputs}. Use dummy variables or var.copy() to distinguish "
f"variables when creating the OpFromGraph graph."
)

if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")
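The error message above suggests using var.copy() for duplicated inputs. A minimal sketch of that workaround, assuming only standard pytensor imports (the variable names are illustrative):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.dscalar("x")

# OpFromGraph([x, x], [x + x]) would now raise the ValueError introduced above.
# Workaround: pass a copy so the inner graph sees two distinct input variables.
x2 = x.copy()
op = OpFromGraph([x, x2], [x + x2])

# The outer inputs may still repeat the same variable.
out = op(x, x)
assert out.eval({x: 3.0}) == 6.0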
17 changes: 15 additions & 2 deletions pytensor/tensor/nlinalg.py
@@ -7,6 +7,7 @@
from numpy.core.numeric import normalize_axis_tuple # type: ignore

from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
@@ -614,7 +615,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):

Returns
-------
U, V, D : matrices

"""
return Blockwise(SVD(full_matrices, compute_uv))(a)
@@ -1011,6 +1012,12 @@ def tensorsolve(a, b, axes=None):
return TensorSolve(axes)(a, b)


class KroneckerProduct(OpFromGraph):
"""
Wrapper Op for Kronecker graphs
"""


def kron(a, b):
"""Kronecker product.

@@ -1027,6 +1034,11 @@ def kron(a, b):
"""
a = as_tensor_variable(a)
b = as_tensor_variable(b)

if a is b:
# In case a is the same as b, we need a different variable to build the OFG
b = a.copy()

if a.ndim + b.ndim <= 2:
raise TypeError(
"kron: inputs dimensions must sum to 3 or more. "
@@ -1042,7 +1054,8 @@ def kron(a, b):
out_shape = tuple(a.shape * b.shape)
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
return output_reshaped

return KroneckerProduct(inputs=[a, b], outputs=[output_reshaped])(a, b)


__all__ = [
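A rough usage sketch (not part of the diff) of the new behavior: kron now returns the output of a KroneckerProduct OpFromGraph, and the numerical result should still agree with np.kron:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import KroneckerProduct

a = pt.dmatrix("a")
b = pt.dmatrix("b")
z = pt.linalg.kron(a, b)

# The symbolic output is now wrapped in the KroneckerProduct OpFromGraph.
assert isinstance(z.owner.op, KroneckerProduct)

f = pytensor.function([a, b], z)
a_val = np.arange(4.0).reshape(2, 2)
b_val = np.eye(3)
np.testing.assert_allclose(f(a_val, b_val), np.kron(a_val, b_val))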
74 changes: 73 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
@@ -1,22 +1,35 @@
import logging
from collections.abc import Callable
from typing import cast

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.nlinalg import (
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
Solve,
SolveBase,
block_diag,
cholesky,
solve,
solve_triangular,
@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node):

# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.


@register_specialize
@node_rewriter([Blockwise])
def local_lift_through_linalg(
fgraph: FunctionGraph, node: Apply
) -> list[Variable] | None:
"""
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
that join matrices (KroneckerProduct, BlockDiagonal).

This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
operations on component matrices instead of one large one. For example, when taking the inverse of a Kronecker
product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
matrices.

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""

# TODO: Simplify this if we end up Blockwising KroneckerProduct
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
y = node.inputs[0]
outer_op = node.op

if y.owner and (
isinstance(y.owner.op, Blockwise)
and isinstance(y.owner.op.core_op, BlockDiagonal)
or isinstance(y.owner.op, KroneckerProduct)
):
input_matrices = y.owner.inputs

if isinstance(outer_op.core_op, MatrixInverse):
outer_f = cast(Callable, inv)
elif isinstance(outer_op.core_op, Cholesky):
outer_f = cast(Callable, cholesky)
elif isinstance(outer_op.core_op, MatrixPinv):
outer_f = cast(Callable, pinv)
else:
raise NotImplementedError # pragma: no cover

inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]

if isinstance(y.owner.op, KroneckerProduct):
return [kron(*inner_matrices)]
elif isinstance(y.owner.op.core_op, BlockDiagonal):
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover
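A rough sketch (not part of the diff) of what this rewrite buys: with the specialization enabled, inv(kron(A, B)) is lifted to kron(inv(A), inv(B)), so two small inverses replace one large one, and both compilations should agree numerically:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix("A")
B = pt.dmatrix("B")
X = pt.linalg.inv(pt.linalg.kron(A, B))

# Compile with and without the new rewrite; the results should match.
f_rewritten = pytensor.function(
    [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
)
f_reference = pytensor.function(
    [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
)

rng = np.random.default_rng(0)
A_val = rng.normal(size=(3, 3)) + 3 * np.eye(3)  # keep the matrices well conditioned
B_val = rng.normal(size=(3, 3)) + 3 * np.eye(3)
np.testing.assert_allclose(
    f_rewritten(A_val, B_val), f_reference(A_val, B_val), atol=1e-8
)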
17 changes: 17 additions & 0 deletions tests/compile/test_builders.py
@@ -118,6 +118,7 @@ def test_grad_grad(self, cls_ofg):
f = op(x, y, z)
f = f - grad(pt_sum(f), y)
f = f - grad(pt_sum(f), y)

fn = function([x, y, z], f)
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
@@ -584,6 +585,22 @@ def test_explicit_input_from_shared(self):
out = test_ofg(y, y)
assert out.eval() == 4

def test_repeated_inputs(self):
x = pt.dscalar("x")
y = pt.dscalar("y")

with pytest.raises(
ValueError,
match="There following variables were provided more than once as inputs to the "
"OpFromGraph",
):
OpFromGraph([x, x, y], [x + y])

# Test that repeated inputs will be allowed if unused inputs are ignored
g = OpFromGraph([x, x, y], [x + y], on_unused_input="ignore")
f = g(x, x, y)
assert f.eval({x: 5, y: 5}) == 10


@config.change_flags(floatX="float64")
def test_debugprint():
59 changes: 58 additions & 1 deletion tests/tensor/rewriting/test_linalg.py
@@ -14,9 +14,16 @@
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.nlinalg import (
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
Solve,
SolveBase,
@@ -333,3 +340,53 @@ def test_invalid_batched_a(self):
ref_fn(test_a, test_b),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
)


@pytest.mark.parametrize(
"constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
)
@pytest.mark.parametrize(
"f_op, f",
[
(MatrixInverse, pt.linalg.inv),
(Cholesky, pt.linalg.cholesky),
(MatrixPinv, pt.linalg.pinv),
],
ids=["inv", "cholesky", "pinv"],
)
@pytest.mark.parametrize(
"g_op, g",
[(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
ids=["block_diag", "kron"],
)
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
if pytensor.config.floatX.endswith("32"):
pytest.skip("Test is flaky at half precision")

A, B = list(map(constructor, "ab"))
X = f(g(A, B))

f1 = pytensor.function(
[A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
)
f2 = pytensor.function(
[A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
)

all_apply_nodes = f1.maker.fgraph.apply_nodes
f_ops = [
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), f_op)
]
g_ops = [
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), g_op)
]

assert len(f_ops) == 2
assert len(g_ops) == 1

test_vals = [
np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
]
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
8 changes: 8 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -590,6 +590,14 @@ def setup_method(self):
self.op = kron
super().setup_method()

def test_vec_vec_kron_raises(self):
x = vector()
y = vector()
with pytest.raises(
TypeError, match="kron: inputs dimensions must sum to 3 or more"
):
kron(x, y)

@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1):
4 changes: 2 additions & 2 deletions tests/tensor/test_slinalg.py
@@ -514,8 +514,8 @@ def test_expm_grad_3():
def test_solve_discrete_lyapunov_via_direct_real():
N = 5
rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix()
q = pt.dmatrix()
a = pt.dmatrix("a")
q = pt.dmatrix("q")
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])

A = rng.normal(size=(N, N))