Skip to content

Commit 7867b16

Browse files
committed
Added test for sampling prior and posterior predictives from a mixture based on issue pymc-devs#3270
1 parent a3f7caa commit 7867b16

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

pymc3/tests/test_mixture.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from numpy.testing import assert_allclose
33

44
from .helpers import SeededTest
5+
import pymc3 as pm
56
from pymc3 import Dirichlet, Gamma, Normal, Lognormal, Poisson, Exponential, \
67
Mixture, NormalMixture, MvNormal, sample, Metropolis, Model
78
import scipy.stats as st
@@ -231,3 +232,53 @@ def mixmixlogp(value, point):
231232
assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
232233
assert_allclose(priorlogp + mixmixlogpg.sum(),
233234
model.logp(test_point))
235+
236+
def test_sample_prior_and_posterior(self):
    """Prior and posterior predictive sampling from a multivariate mixture.

    Regression test for pymc-devs/pymc3#3270: builds a 3-component
    2-d MvNormal mixture, draws a short trace, and checks that
    ``sample_posterior_predictive`` / ``sample_prior_predictive``
    return arrays of the expected shapes for the mixture variable.
    """
    def build_toy_dataset(N, K):
        # Draw N two-dimensional points from a fixed 3-component
        # diagonal-covariance Gaussian mixture; returns (points, labels).
        pi = np.array([0.2, 0.5, 0.3])
        mus = [[1, 1], [-1, -1], [2, -2]]
        stds = [[0.1, 0.1], [0.1, 0.2], [0.2, 0.3]]
        x = np.zeros((N, 2), dtype=np.float32)
        # Use the builtin ``int`` as the dtype: ``np.int`` was deprecated
        # in NumPy 1.20 and removed in NumPy 1.24.
        y = np.zeros((N,), dtype=int)
        for n in range(N):
            k = np.argmax(np.random.multinomial(1, pi))
            x[n, :] = np.random.multivariate_normal(mus[k],
                                                    np.diag(stds[k]))
            y[n] = k
        return x, y

    N = 100  # number of data points
    D = 2    # dimensionality of data

    X, y = build_toy_dataset(N, 3)

    K = 3
    with pm.Model() as model:
        pi = pm.Dirichlet('pi', np.ones(K))

        comp_dist = []
        mu = []
        packed_chol = []
        chol = []
        for i in range(K):
            mu.append(pm.Normal('mu%i' % i, 0, 10, shape=2))
            # Sample a packed Cholesky factor per component and expand it
            # so each component gets its own full covariance structure.
            packed_chol.append(
                pm.LKJCholeskyCov('chol_cov_%i' % i,
                                  eta=2,
                                  n=2,
                                  sd_dist=pm.HalfNormal.dist(2.5))
            )
            chol.append(pm.expand_packed_triangular(2, packed_chol[i],
                                                    lower=True))
            comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i]))

        pm.Mixture('x_obs', pi, comp_dist, observed=X)
    with model:
        # Deliberately tiny trace: the test checks shapes, not inference.
        trace = pm.sample(30, tune=10, chains=1)

    n_samples = 20
    with model:
        ppc = pm.sample_posterior_predictive(trace, n_samples)
        prior = pm.sample_prior_predictive(samples=n_samples)
    # Posterior predictive keeps the observed data's shape per draw;
    # prior predictive returns one D-dimensional point per draw.
    assert ppc['x_obs'].shape == (n_samples,) + X.shape
    assert prior['x_obs'].shape == (n_samples, D)

0 commit comments

Comments (0)