Skip to content

Commit a3521e8

Browse files
Added working KroneckerNormal random tests
1 parent 2239bf2 commit a3521e8

File tree

1 file changed

+54
-2
lines changed

1 file changed

+54
-2
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
from .test_distributions import (
1414
build_model, Domain, product, R, Rplus, Rplusbig, Rplusdunif,
1515
Unit, Nat, NatSmall, I, Simplex, Vector, PdMatrix,
16-
PdMatrixChol, PdMatrixCholUpper, RealMatrix
16+
PdMatrixChol, PdMatrixCholUpper, RealMatrix, RandomPdMatrix
1717
)
1818

1919

2020
def pymc3_random(dist, paramdomains, ref_rand, valuedomain=Domain([0]),
21-
size=10000, alpha=0.05, fails=10, extra_args=None):
21+
size=10000, alpha=0.05, fails=10, extra_args=None,
22+
model_args=None):
23+
if model_args is None:
24+
model_args = {}
2225
model = build_model(dist, valuedomain, paramdomains, extra_args)
2326
domains = paramdomains.copy()
2427
for pt in product(domains, n_samples=100):
2528
pt = pm.Point(pt, model=model)
29+
pt.update(model_args)
2630
p = alpha
2731
# Allow KS test to fail (i.e., the samples be different)
2832
# a certain number of times. Crude, but necessary.
@@ -586,6 +590,54 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
586590
# extra_args={'lower': False}
587591
# )
588592

593+
def test_kronecker_normal(self):
594+
def ref_rand(size, mu, covs, sigma):
595+
cov = pm.math.kronecker(covs[0], covs[1]).eval()
596+
cov += sigma**2 * np.identity(cov.shape[0])
597+
return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size)
598+
599+
def ref_rand_chol(size, mu, chols, sigma):
600+
covs = [np.dot(chol, chol.T) for chol in chols]
601+
return ref_rand(size, mu, covs, sigma)
602+
603+
def ref_rand_evd(size, mu, evds, sigma):
604+
covs = []
605+
for eigs, Q in evds:
606+
covs.append(np.dot(Q, np.dot(np.diag(eigs), Q.T)))
607+
return ref_rand(size, mu, covs, sigma)
608+
609+
sizes = [2, 3]
610+
sigmas = [0, 1]
611+
for n, sigma in zip(sizes, sigmas):
612+
N = n**2
613+
covs = [RandomPdMatrix(n), RandomPdMatrix(n)]
614+
chols = list(map(np.linalg.cholesky, covs))
615+
evds = list(map(np.linalg.eigh, covs))
616+
dom = Domain([np.random.randn(N)*0.1], edges=(None, None), shape=N)
617+
mu = Domain([np.random.randn(N)*0.1], edges=(None, None), shape=N)
618+
619+
std_args = {'mu': mu}
620+
cov_args = {'covs': covs}
621+
chol_args = {'chols': chols}
622+
evd_args = {'evds': evds}
623+
if sigma is not None and sigma != 0:
624+
std_args['sigma'] = Domain([sigma], edges=(None, None))
625+
else:
626+
for args in [cov_args, chol_args, evd_args]:
627+
args['sigma'] = sigma
628+
629+
pymc3_random(
630+
pm.KroneckerNormal, std_args, valuedomain=dom,
631+
ref_rand=ref_rand, extra_args=cov_args, model_args=cov_args)
632+
pymc3_random(
633+
pm.KroneckerNormal, std_args, valuedomain=dom,
634+
ref_rand=ref_rand_chol, extra_args=chol_args,
635+
model_args=chol_args)
636+
pymc3_random(
637+
pm.KroneckerNormal, std_args, valuedomain=dom,
638+
ref_rand=ref_rand_evd, extra_args=evd_args,
639+
model_args=evd_args)
640+
589641
def test_mv_t(self):
590642
def ref_rand(size, nu, Sigma, mu):
591643
normal = st.multivariate_normal.rvs(cov=Sigma, size=size).T

0 commit comments

Comments
 (0)