import typing
import warnings
from functools import reduce
- from typing import TYPE_CHECKING, Literal, cast
+ from typing import Literal, cast

import numpy as np
import scipy.linalg

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
- from pytensor.tensor import as_tensor_variable
+ from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.variable import TensorVariable


- if TYPE_CHECKING:
-     from pytensor.tensor import TensorLike
-
logger = logging.getLogger(__name__)
@@ -777,7 +774,16 @@ def perform(self, node, inputs, outputs):


class SolveContinuousLyapunov(Op):
+     """
+     Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`.
+
+     Continuous-time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
+     efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
+     scipy.linalg.solve_continuous_lyapunov.
+     """
+
    __props__ = ()
+     gufunc_signature = "(m,m),(m,m)->(m,m)"

    def make_node(self, A, B):
        A = as_tensor_variable(A)
@@ -792,7 +798,8 @@ def perform(self, node, inputs, output_storage):
        (A, B) = inputs
        X = output_storage[0]

-         X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+         out_dtype = node.outputs[0].type.dtype
+         X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]
@@ -813,7 +820,41 @@ def grad(self, inputs, output_grads):
        return [A_bar, Q_bar]


+ _solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
+
+
+ def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
+     """
+     Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
+
+     Parameters
+     ----------
+     A: TensorLike
+         Square matrix of shape ``N x N``.
+     Q: TensorLike
+         Square matrix of shape ``N x N``.
+
+     Returns
+     -------
+     X: TensorVariable
+         Square matrix of shape ``N x N``.
+
+     """
+
+     return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
+
+
class BilinearSolveDiscreteLyapunov(Op):
+     """
+     Solves a discrete Lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X`.
+
+     The solution is computed by first transforming the discrete-time problem into a continuous-time form. The
+     continuous-time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved. For
+     more details, see the docstring for scipy.linalg.solve_discrete_lyapunov.
+     """
+
+     gufunc_signature = "(m,m),(m,m)->(m,m)"
+
    def make_node(self, A, B):
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
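A minimal usage sketch for the new wrapper (illustrative only, assuming solve_continuous_lyapunov is exposed from pytensor.tensor.slinalg, where this diff defines it; variable names and values are made up):

import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy.linalg
from pytensor.tensor.slinalg import solve_continuous_lyapunov

A = pt.dmatrix("A")
Q = pt.dmatrix("Q")
X = solve_continuous_lyapunov(A, Q)  # symbolic N x N solution
f = pytensor.function([A, Q], X)

rng = np.random.default_rng(0)
A_val = rng.normal(size=(3, 3))
Q_val = rng.normal(size=(3, 3))

# perform() delegates to SciPy, so the compiled graph should reproduce it
np.testing.assert_allclose(
    f(A_val, Q_val), scipy.linalg.solve_continuous_lyapunov(A_val, Q_val)
)

# Because the Op is wrapped in Blockwise, leading batch dimensions should also work
A_batch = pt.tensor3("A_batch")  # shape (..., N, N)
Q_batch = pt.tensor3("Q_batch")
X_batch = solve_continuous_lyapunov(A_batch, Q_batch)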
@@ -827,7 +868,10 @@ def perform(self, node, inputs, output_storage):
        (A, B) = inputs
        X = output_storage[0]

-         X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+         out_dtype = node.outputs[0].type.dtype
+         X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
+             out_dtype
+         )

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]
@@ -849,83 +893,83 @@ def grad(self, inputs, output_grads):
        return [A_bar, Q_bar]


- _solve_continuous_lyapunov = SolveContinuousLyapunov()
- _solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
+ _bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())


- def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-     A_ = as_tensor_variable(A)
-     Q_ = as_tensor_variable(Q)
+ def _direct_solve_discrete_lyapunov(
+     A: TensorVariable, Q: TensorVariable
+ ) -> TensorVariable:
+     r"""
+     Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the Kronecker method of Magnus and
+     Neudecker.
+
+     This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 \times N^2`.
+     As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
+     """

-     if "complex" in A_.type.dtype:
-         AA = kron(A_, A_.conj())
+     if A.type.dtype.startswith("complex"):
+         AxA = kron(A, A.conj())
    else:
-         AA = kron(A_, A_)
+         AxA = kron(A, A)
+
+     eye = pt.eye(AxA.shape[-1])

-     X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
-     return cast(TensorVariable, reshape(X, Q_.shape))
+     vec_Q = Q.ravel()
+     vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
+
+     return cast(TensorVariable, reshape(vec_X, A.shape))
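For reference, the Kronecker construction above rests on the row-major vectorization identity vec(A X A^H) = (A ⊗ conj(A)) vec(X), which is what pairs the .ravel() calls with the final reshape. A quick NumPy check of that identity (illustrative only, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
N = 4
A = rng.normal(size=(N, N)) + 1j * rng.normal(size=(N, N))
X = rng.normal(size=(N, N))

# Row-major (C-order) vectorization: vec(A @ X @ A^H) == kron(A, conj(A)) @ vec(X)
lhs = (A @ X @ A.conj().T).ravel()
rhs = np.kron(A, A.conj()) @ X.ravel()
assert np.allclose(lhs, rhs)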


def solve_discrete_lyapunov(
-     A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
+     A: TensorLike,
+     Q: TensorLike,
+     method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
    """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.

    Parameters
    ----------
-     A
-         Square matrix of shape N x N; must have the same shape as Q
-     Q
-         Square matrix of shape N x N; must have the same shape as A
-     method
-         Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
-         solves the problem directly via matrix inversion. This has a pure
-         PyTensor implementation and can thus be cross-compiled to supported
-         backends, and should be preferred when ``N`` is not large. The direct
-         method scales poorly with the size of ``N``, and the bilinear can be
+     A: TensorLike
+         Square matrix of shape N x N
+     Q: TensorLike
+         Square matrix of shape N x N
+     method: str, one of ``"direct"`` or ``"bilinear"``
+         Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
+         PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
+         ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear method can be
        used in these cases.

    Returns
    -------
-         Square matrix of shape ``N x N``, representing the solution to the
-         Lyapunov equation
+     X: TensorVariable
+         Square matrix of shape ``N x N``. Solution to the Lyapunov equation

    """
    if method not in ["direct", "bilinear"]:
        raise ValueError(
            f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
        )

-     if method == "direct":
-         return _direct_solve_discrete_lyapunov(A, Q)
-     if method == "bilinear":
-         return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
-
-
- def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-     """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
-
-     Parameters
-     ----------
-     A
-         Square matrix of shape ``N x N``; must have the same shape as `Q`.
-     Q
-         Square matrix of shape ``N x N``; must have the same shape as `A`.
+     A = as_tensor_variable(A)
+     Q = as_tensor_variable(Q)

-     Returns
-     -------
-         Square matrix of shape ``N x N``, representing the solution to the
-         Lyapunov equation
+     if method == "direct":
+         signature = BilinearSolveDiscreteLyapunov.gufunc_signature
+         X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
+         return cast(TensorVariable, X)

-     """
+     elif method == "bilinear":
+         return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))

-     return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
+     else:
+         raise ValueError(f"Unknown method {method}")
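A rough sketch of how the two dispatch paths could be exercised (again assuming the function is importable from pytensor.tensor.slinalg; the matrix values are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_lyapunov

A = pt.dmatrix("A")
Q = pt.dmatrix("Q")

# "bilinear" (the new default) uses the Blockwise-wrapped SciPy Op;
# "direct" builds a pure PyTensor graph via the Kronecker method and pt.vectorize
X_bilinear = solve_discrete_lyapunov(A, Q, method="bilinear")
X_direct = solve_discrete_lyapunov(A, Q, method="direct")
f = pytensor.function([A, Q], [X_bilinear, X_direct])

# Triangular A keeps all eigenvalues (0.5, 0.4, 0.3) inside the unit circle
A_val = np.array([[0.5, 0.2, 0.0], [0.0, 0.4, 0.1], [0.0, 0.0, 0.3]])
Q_val = np.eye(3)

x_b, x_d = f(A_val, Q_val)
np.testing.assert_allclose(x_b, x_d, rtol=1e-6, atol=1e-8)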


- class SolveDiscreteARE(pt.Op):
+ class SolveDiscreteARE(Op):
    __props__ = ("enforce_Q_symmetric",)
+     gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

-     def __init__(self, enforce_Q_symmetric=False):
+     def __init__(self, enforce_Q_symmetric: bool = False):
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
@@ -946,9 +990,8 @@ def perform(self, node, inputs, output_storage):
        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)

-         X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
-             node.outputs[0].type.dtype
-         )
+         out_dtype = node.outputs[0].type.dtype
+         X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]
@@ -960,14 +1003,16 @@ def grad(self, inputs, output_grads):
        (dX,) = output_grads
        X = self(A, B, Q, R)

-         K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
-         K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
-         K = matrix_dot(K_inner_inv, B.T, X, A)
+         K_inner = R + matrix_dot(B.T, X, B)
+
+         # K_inner is guaranteed to be symmetric, because X and R are symmetric
+         K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
+         K = matrix_dot(K_inner_inv_BT, X, A)

        A_tilde = A - B.dot(K)

        dX_symm = 0.5 * (dX + dX.T)
-         S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
+         S = solve_discrete_lyapunov(A_tilde, dX_symm)

        A_bar = 2 * matrix_dot(X, A_tilde, S)
        B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
@@ -977,30 +1022,45 @@ def grad(self, inputs, output_grads):
        return [A_bar, B_bar, Q_bar, R_bar]


- def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
+ def solve_discrete_are(
+     A: TensorLike,
+     B: TensorLike,
+     Q: TensorLike,
+     R: TensorLike,
+     enforce_Q_symmetric: bool = False,
+ ) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

+     Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
+     solution to Linear-Quadratic Regulator (LQR) and Linear-Quadratic-Gaussian (LQG) control problems, and as the
+     steady-state covariance of the Kalman Filter.
+
+     Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
+     solution. This stable solution, if it exists, will be returned by this function.
+
    Parameters
    ----------
-     A: ArrayLike
+     A: TensorLike
        Square matrix of shape M x M
-     B: ArrayLike
+     B: TensorLike
        Square matrix of shape M x M
-     Q: ArrayLike
+     Q: TensorLike
        Symmetric square matrix of shape M x M
-     R: ArrayLike
+     R: TensorLike
        Square matrix of shape N x N
    enforce_Q_symmetric: bool
        If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

    Returns
    -------
-     X: pt.matrix
+     X: TensorVariable
        Square matrix of shape M x M, representing the solution to the DARE
    """

-     return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
+     return cast(
+         TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
+     )


def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
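Finally, a small sketch of the DARE wrapper against SciPy on a toy LQR-style problem (values are illustrative; note that B has shape M x N per the Op's gufunc signature, here M=2 and N=1):

import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy.linalg
from pytensor.tensor.slinalg import solve_discrete_are

A = pt.dmatrix("A")
B = pt.dmatrix("B")
Q = pt.dmatrix("Q")
R = pt.dmatrix("R")

X = solve_discrete_are(A, B, Q, R)
f = pytensor.function([A, B, Q, R], X)

A_val = np.array([[0.9, 0.1], [0.0, 0.8]])  # stable 2 x 2 state transition
B_val = np.array([[0.0], [1.0]])            # single control input (M=2, N=1)
Q_val = np.eye(2)
R_val = np.array([[1.0]])

# The stabilizing solution should match scipy.linalg.solve_discrete_are
np.testing.assert_allclose(
    f(A_val, B_val, Q_val, R_val),
    scipy.linalg.solve_discrete_are(A_val, B_val, Q_val, R_val),
)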