Skip to content

Commit c33d571

Browse files
committed
Added a test to signify need of transpose of cholesky matrix.
1 parent b69c109 commit c33d571

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,12 @@ def ref_rand_chol(size, mu, rowchol, colchol):
849849
size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T)
850850
)
851851

852+
def ref_rand_chol_transpose(size, mu, rowchol, colchol):
853+
colchol = colchol.T
854+
return ref_rand(
855+
size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T)
856+
)
857+
852858
def ref_rand_uchol(size, mu, rowchol, colchol):
853859
return ref_rand(
854860
size, mu, rowcov=np.dot(rowchol.T, rowchol), colcov=np.dot(colchol.T, colchol)
@@ -858,7 +864,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
858864
pymc3_random(
859865
pm.MatrixNormal,
860866
{"mu": RealMatrix(n, n), "rowcov": PdMatrix(n), "colcov": PdMatrix(n)},
861-
size=n,
867+
size=100,
862868
valuedomain=RealMatrix(n, n),
863869
ref_rand=ref_rand,
864870
)
@@ -867,7 +873,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
867873
pymc3_random(
868874
pm.MatrixNormal,
869875
{"mu": RealMatrix(n, n), "rowchol": PdMatrixChol(n), "colchol": PdMatrixChol(n)},
870-
size=n,
876+
size=100,
871877
valuedomain=RealMatrix(n, n),
872878
ref_rand=ref_rand_chol,
873879
)
@@ -878,6 +884,22 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
878884
# extra_args={'lower': False}
879885
# )
880886

887+
# 2 sample test fails because cov becomes different if chol is transposed beforehand.
888+
# This implicity means we need transpose of chol after drawing values in
889+
# MatrixNormal.random method to match stats.matrix_normal.rvs method
890+
with pytest.raises(AssertionError):
891+
pymc3_random(
892+
pm.MatrixNormal,
893+
{
894+
"mu": RealMatrix(n, n),
895+
"rowchol": PdMatrixChol(n),
896+
"colchol": PdMatrixChol(n),
897+
},
898+
size=100,
899+
valuedomain=RealMatrix(n, n),
900+
ref_rand=ref_rand_chol_transpose,
901+
)
902+
881903
def test_kronecker_normal(self):
882904
def ref_rand(size, mu, covs, sigma):
883905
cov = pm.math.kronecker(covs[0], covs[1]).eval()

0 commit comments

Comments
 (0)