@@ -849,6 +849,12 @@ def ref_rand_chol(size, mu, rowchol, colchol):
849
849
size , mu , rowcov = np .dot (rowchol , rowchol .T ), colcov = np .dot (colchol , colchol .T )
850
850
)
851
851
852
+ def ref_rand_chol_transpose (size , mu , rowchol , colchol ):
853
+ colchol = colchol .T
854
+ return ref_rand (
855
+ size , mu , rowcov = np .dot (rowchol , rowchol .T ), colcov = np .dot (colchol , colchol .T )
856
+ )
857
+
852
858
def ref_rand_uchol (size , mu , rowchol , colchol ):
853
859
return ref_rand (
854
860
size , mu , rowcov = np .dot (rowchol .T , rowchol ), colcov = np .dot (colchol .T , colchol )
@@ -858,7 +864,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
858
864
pymc3_random (
859
865
pm .MatrixNormal ,
860
866
{"mu" : RealMatrix (n , n ), "rowcov" : PdMatrix (n ), "colcov" : PdMatrix (n )},
861
- size = n ,
867
+ size = 100 ,
862
868
valuedomain = RealMatrix (n , n ),
863
869
ref_rand = ref_rand ,
864
870
)
@@ -867,7 +873,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
867
873
pymc3_random (
868
874
pm .MatrixNormal ,
869
875
{"mu" : RealMatrix (n , n ), "rowchol" : PdMatrixChol (n ), "colchol" : PdMatrixChol (n )},
870
- size = n ,
876
+ size = 100 ,
871
877
valuedomain = RealMatrix (n , n ),
872
878
ref_rand = ref_rand_chol ,
873
879
)
@@ -878,6 +884,22 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
878
884
# extra_args={'lower': False}
879
885
# )
880
886
887
+ # 2 sample test fails because cov becomes different if chol is transposed beforehand.
888
+ # This implicity means we need transpose of chol after drawing values in
889
+ # MatrixNormal.random method to match stats.matrix_normal.rvs method
890
+ with pytest .raises (AssertionError ):
891
+ pymc3_random (
892
+ pm .MatrixNormal ,
893
+ {
894
+ "mu" : RealMatrix (n , n ),
895
+ "rowchol" : PdMatrixChol (n ),
896
+ "colchol" : PdMatrixChol (n ),
897
+ },
898
+ size = 100 ,
899
+ valuedomain = RealMatrix (n , n ),
900
+ ref_rand = ref_rand_chol_transpose ,
901
+ )
902
+
881
903
def test_kronecker_normal (self ):
882
904
def ref_rand (size , mu , covs , sigma ):
883
905
cov = pm .math .kronecker (covs [0 ], covs [1 ]).eval ()
0 commit comments