Skip to content

Commit fffb84c

Browse files
Cleanup for Optimal Control Ops (#1045)
* Blockwise optimal linear control ops * Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov` * set `solve_discrete_lyapunov` method default to bilinear * Appease mypy * restore method dispatching * Use `pt.vectorize` on base `solve_discrete_lyapunov` case * Apply JAX rewrite before canonicalization * Improve tests * Remove useless warning filters * Fix local_blockwise_alloc rewrite The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left. * Fix float32 tests * Test against complex inputs * Appease ViPy (Vieira-py type checking) * Remove condition from `TensorLike` import * Infer dtype from `node.outputs.type.dtype` * Remove unused mypy ignore * Don't manually set dtype of output Revert change to `_solve_discrete_lyapunov` * Set dtype of Op outputs --------- Co-authored-by: ricardoV94 <[email protected]>
1 parent dae731d commit fffb84c

File tree

5 files changed

+301
-124
lines changed

5 files changed

+301
-124
lines changed

pytensor/tensor/rewriting/blockwise.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
127127
value, *shape = inp.owner.inputs
128128

129129
# Check what to do with the value of the Alloc
130-
squeezed_value = _squeeze_left(value, batch_ndim)
131-
missing_ndim = len(shape) - value.type.ndim
130+
missing_ndim = inp.type.ndim - value.type.ndim
131+
squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
132132
if (
133133
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
134134
!= inp.type.broadcastable[batch_ndim:]

pytensor/tensor/rewriting/linalg.py

+23
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from pytensor import Variable
66
from pytensor import tensor as pt
7+
from pytensor.compile import optdb
78
from pytensor.graph import Apply, FunctionGraph
89
from pytensor.graph.rewriting.basic import (
910
copy_stack_trace,
11+
in2out,
1012
node_rewriter,
1113
)
1214
from pytensor.scalar.basic import Mul
@@ -45,9 +47,11 @@
4547
Cholesky,
4648
Solve,
4749
SolveBase,
50+
_bilinear_solve_discrete_lyapunov,
4851
block_diag,
4952
cholesky,
5053
solve,
54+
solve_discrete_lyapunov,
5155
solve_triangular,
5256
)
5357

@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
966970
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
967971

968972
return [eye_input * (non_eye_input**0.5)]
973+
974+
975+
@node_rewriter([_bilinear_solve_discrete_lyapunov])
976+
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
977+
"""
978+
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979+
"""
980+
A, B = (cast(TensorVariable, x) for x in node.inputs)
981+
result = solve_discrete_lyapunov(A, B, method="direct")
982+
983+
return [result]
984+
985+
986+
optdb.register(
987+
"jax_bilinaer_lyapunov_to_direct",
988+
in2out(jax_bilinaer_lyapunov_to_direct),
989+
"jax",
990+
position=0.9, # Run before canonicalization
991+
)

pytensor/tensor/slinalg.py

+127-67
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33
import warnings
44
from functools import reduce
5-
from typing import TYPE_CHECKING, Literal, cast
5+
from typing import Literal, cast
66

77
import numpy as np
88
import scipy.linalg
@@ -11,7 +11,7 @@
1111
import pytensor.tensor as pt
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.op import Op
14-
from pytensor.tensor import as_tensor_variable
14+
from pytensor.tensor import TensorLike, as_tensor_variable
1515
from pytensor.tensor import basic as ptb
1616
from pytensor.tensor import math as ptm
1717
from pytensor.tensor.blockwise import Blockwise
@@ -21,9 +21,6 @@
2121
from pytensor.tensor.variable import TensorVariable
2222

2323

24-
if TYPE_CHECKING:
25-
from pytensor.tensor import TensorLike
26-
2724
logger = logging.getLogger(__name__)
2825

2926

@@ -777,7 +774,16 @@ def perform(self, node, inputs, outputs):
777774

778775

779776
class SolveContinuousLyapunov(Op):
777+
"""
778+
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`.
779+
780+
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
781+
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
782+
scipy.linalg.solve_continuous_lyapunov
783+
"""
784+
780785
__props__ = ()
786+
gufunc_signature = "(m,m),(m,m)->(m,m)"
781787

782788
def make_node(self, A, B):
783789
A = as_tensor_variable(A)
@@ -792,7 +798,8 @@ def perform(self, node, inputs, output_storage):
792798
(A, B) = inputs
793799
X = output_storage[0]
794800

795-
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
801+
out_dtype = node.outputs[0].type.dtype
802+
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
796803

797804
def infer_shape(self, fgraph, node, shapes):
798805
return [shapes[0]]
@@ -813,7 +820,41 @@ def grad(self, inputs, output_grads):
813820
return [A_bar, Q_bar]
814821

815822

823+
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
824+
825+
826+
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
827+
"""
828+
Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`.
829+
830+
Parameters
831+
----------
832+
A: TensorLike
833+
Square matrix of shape ``N x N``.
834+
Q: TensorLike
835+
Square matrix of shape ``N x N``.
836+
837+
Returns
838+
-------
839+
X: TensorVariable
840+
Square matrix of shape ``N x N``
841+
842+
"""
843+
844+
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
845+
846+
816847
class BilinearSolveDiscreteLyapunov(Op):
848+
"""
849+
Solves a discrete Lyapunov equation, :math:`AXA^H - X + Q = 0`, for :math:`X`.
850+
851+
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
852+
time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
853+
docstring for scipy.linalg.solve_discrete_lyapunov
854+
"""
855+
856+
gufunc_signature = "(m,m),(m,m)->(m,m)"
857+
817858
def make_node(self, A, B):
818859
A = as_tensor_variable(A)
819860
B = as_tensor_variable(B)
@@ -827,7 +868,10 @@ def perform(self, node, inputs, output_storage):
827868
(A, B) = inputs
828869
X = output_storage[0]
829870

830-
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
871+
out_dtype = node.outputs[0].type.dtype
872+
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
873+
out_dtype
874+
)
831875

832876
def infer_shape(self, fgraph, node, shapes):
833877
return [shapes[0]]
@@ -849,83 +893,83 @@ def grad(self, inputs, output_grads):
849893
return [A_bar, Q_bar]
850894

851895

852-
_solve_continuous_lyapunov = SolveContinuousLyapunov()
853-
_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
896+
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())
854897

855898

856-
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
857-
A_ = as_tensor_variable(A)
858-
Q_ = as_tensor_variable(Q)
899+
def _direct_solve_discrete_lyapunov(
900+
A: TensorVariable, Q: TensorVariable
901+
) -> TensorVariable:
902+
r"""
903+
Directly solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0` using the Kronecker method of Magnus and
904+
Neudecker.
905+
906+
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
907+
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
908+
"""
859909

860-
if "complex" in A_.type.dtype:
861-
AA = kron(A_, A_.conj())
910+
if A.type.dtype.startswith("complex"):
911+
AxA = kron(A, A.conj())
862912
else:
863-
AA = kron(A_, A_)
913+
AxA = kron(A, A)
914+
915+
eye = pt.eye(AxA.shape[-1])
864916

865-
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
866-
return cast(TensorVariable, reshape(X, Q_.shape))
917+
vec_Q = Q.ravel()
918+
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
919+
920+
return cast(TensorVariable, reshape(vec_X, A.shape))
867921

868922

869923
def solve_discrete_lyapunov(
870-
A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
924+
A: TensorLike,
925+
Q: TensorLike,
926+
method: Literal["direct", "bilinear"] = "bilinear",
871927
) -> TensorVariable:
872928
"""Solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0`.
873929
874930
Parameters
875931
----------
876-
A
877-
Square matrix of shape N x N; must have the same shape as Q
878-
Q
879-
Square matrix of shape N x N; must have the same shape as A
880-
method
881-
Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
882-
solves the problem directly via matrix inversion. This has a pure
883-
PyTensor implementation and can thus be cross-compiled to supported
884-
backends, and should be preferred when ``N`` is not large. The direct
885-
method scales poorly with the size of ``N``, and the bilinear can be
932+
A: TensorLike
933+
Square matrix of shape N x N
934+
Q: TensorLike
935+
Square matrix of shape N x N
936+
method: str, one of ``"direct"`` or ``"bilinear"``
937+
Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
938+
PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
939+
``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear method can be
886940
used in these cases.
887941
888942
Returns
889943
-------
890-
Square matrix of shape ``N x N``, representing the solution to the
891-
Lyapunov equation
944+
X: TensorVariable
945+
Square matrix of shape ``N x N``. Solution to the Lyapunov equation
892946
893947
"""
894948
if method not in ["direct", "bilinear"]:
895949
raise ValueError(
896950
f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
897951
)
898952

899-
if method == "direct":
900-
return _direct_solve_discrete_lyapunov(A, Q)
901-
if method == "bilinear":
902-
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
903-
904-
905-
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
906-
"""Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
907-
908-
Parameters
909-
----------
910-
A
911-
Square matrix of shape ``N x N``; must have the same shape as `Q`.
912-
Q
913-
Square matrix of shape ``N x N``; must have the same shape as `A`.
953+
A = as_tensor_variable(A)
954+
Q = as_tensor_variable(Q)
914955

915-
Returns
916-
-------
917-
Square matrix of shape ``N x N``, representing the solution to the
918-
Lyapunov equation
956+
if method == "direct":
957+
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
958+
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
959+
return cast(TensorVariable, X)
919960

920-
"""
961+
elif method == "bilinear":
962+
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))
921963

922-
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
964+
else:
965+
raise ValueError(f"Unknown method {method}")
923966

924967

925-
class SolveDiscreteARE(pt.Op):
968+
class SolveDiscreteARE(Op):
926969
__props__ = ("enforce_Q_symmetric",)
970+
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
927971

928-
def __init__(self, enforce_Q_symmetric=False):
972+
def __init__(self, enforce_Q_symmetric: bool = False):
929973
self.enforce_Q_symmetric = enforce_Q_symmetric
930974

931975
def make_node(self, A, B, Q, R):
@@ -946,9 +990,8 @@ def perform(self, node, inputs, output_storage):
946990
if self.enforce_Q_symmetric:
947991
Q = 0.5 * (Q + Q.T)
948992

949-
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
950-
node.outputs[0].type.dtype
951-
)
993+
out_dtype = node.outputs[0].type.dtype
994+
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
952995

953996
def infer_shape(self, fgraph, node, shapes):
954997
return [shapes[0]]
@@ -960,14 +1003,16 @@ def grad(self, inputs, output_grads):
9601003
(dX,) = output_grads
9611004
X = self(A, B, Q, R)
9621005

963-
K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
964-
K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
965-
K = matrix_dot(K_inner_inv, B.T, X, A)
1006+
K_inner = R + matrix_dot(B.T, X, B)
1007+
1008+
# K_inner is guaranteed to be symmetric, because X and R are symmetric
1009+
K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
1010+
K = matrix_dot(K_inner_inv_BT, X, A)
9661011

9671012
A_tilde = A - B.dot(K)
9681013

9691014
dX_symm = 0.5 * (dX + dX.T)
970-
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
1015+
S = solve_discrete_lyapunov(A_tilde, dX_symm)
9711016

9721017
A_bar = 2 * matrix_dot(X, A_tilde, S)
9731018
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
@@ -977,30 +1022,45 @@ def grad(self, inputs, output_grads):
9771022
return [A_bar, B_bar, Q_bar, R_bar]
9781023

9791024

980-
def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
1025+
def solve_discrete_are(
1026+
A: TensorLike,
1027+
B: TensorLike,
1028+
Q: TensorLike,
1029+
R: TensorLike,
1030+
enforce_Q_symmetric: bool = False,
1031+
) -> TensorVariable:
9811032
"""
9821033
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
9831034
1035+
Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
1036+
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the
1037+
steady-state covariance of the Kalman Filter.
1038+
1039+
Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
1040+
solution. This stable solution, if it exists, will be returned by this function.
1041+
9841042
Parameters
9851043
----------
986-
A: ArrayLike
1044+
A: TensorLike
9871045
Square matrix of shape M x M
988-
B: ArrayLike
1046+
B: TensorLike
9891047
Matrix of shape M x N
990-
Q: ArrayLike
1048+
Q: TensorLike
9911049
Symmetric square matrix of shape M x M
992-
R: ArrayLike
1050+
R: TensorLike
9931051
Square matrix of shape N x N
9941052
enforce_Q_symmetric: bool
9951053
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
9961054
9971055
Returns
9981056
-------
999-
X: pt.matrix
1057+
X: TensorVariable
10001058
Square matrix of shape M x M, representing the solution to the DARE
10011059
"""
10021060

1003-
return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
1061+
return cast(
1062+
TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
1063+
)
10041064

10051065

10061066
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:

0 commit comments

Comments
 (0)