|
2 | 2 | from numpy.testing import assert_allclose
|
3 | 3 |
|
4 | 4 | from .helpers import SeededTest
|
| 5 | +import pymc3 as pm |
5 | 6 | from pymc3 import Dirichlet, Gamma, Normal, Lognormal, Poisson, Exponential, \
|
6 | 7 | Mixture, NormalMixture, MvNormal, sample, Metropolis, Model
|
7 | 8 | import scipy.stats as st
|
@@ -231,3 +232,53 @@ def mixmixlogp(value, point):
|
231 | 232 | assert_allclose(mixmixlogpg, mix.logp_elemwise(test_point))
|
232 | 233 | assert_allclose(priorlogp + mixmixlogpg.sum(),
|
233 | 234 | model.logp(test_point))
|
| 235 | + |
| 236 | + def test_sample_prior_and_posterior(self): |
| 237 | + def build_toy_dataset(N, K): |
| 238 | + pi = np.array([0.2, 0.5, 0.3]) |
| 239 | + mus = [[1, 1], [-1, -1], [2, -2]] |
| 240 | + stds = [[0.1, 0.1], [0.1, 0.2], [0.2, 0.3]] |
| 241 | + x = np.zeros((N, 2), dtype=np.float32) |
| 242 | + y = np.zeros((N,), dtype=np.int) |
| 243 | + for n in range(N): |
| 244 | + k = np.argmax(np.random.multinomial(1, pi)) |
| 245 | + x[n, :] = np.random.multivariate_normal(mus[k], |
| 246 | + np.diag(stds[k])) |
| 247 | + y[n] = k |
| 248 | + return x, y |
| 249 | + |
| 250 | + N = 100 # number of data points |
| 251 | + D = 2 # dimensionality of data |
| 252 | + |
| 253 | + X, y = build_toy_dataset(N, 3) |
| 254 | + |
| 255 | + K = 3 |
| 256 | + with pm.Model() as model: |
| 257 | + pi = pm.Dirichlet('pi', np.ones(K)) |
| 258 | + |
| 259 | + comp_dist = [] |
| 260 | + mu = [] |
| 261 | + packed_chol = [] |
| 262 | + chol = [] |
| 263 | + for i in range(K): |
| 264 | + mu.append(pm.Normal('mu%i' % i, 0, 10, shape=2)) |
| 265 | + packed_chol.append( |
| 266 | + pm.LKJCholeskyCov('chol_cov_%i' % i, |
| 267 | + eta=2, |
| 268 | + n=2, |
| 269 | + sd_dist=pm.HalfNormal.dist(2.5)) |
| 270 | + ) |
| 271 | + chol.append(pm.expand_packed_triangular(2, packed_chol[i], |
| 272 | + lower=True)) |
| 273 | + comp_dist.append(pm.MvNormal.dist(mu=mu[i], chol=chol[i])) |
| 274 | + |
| 275 | + pm.Mixture('x_obs', pi, comp_dist, observed=X) |
| 276 | + with model: |
| 277 | + trace = pm.sample(30, tune=10, chains=1) |
| 278 | + |
| 279 | + n_samples = 20 |
| 280 | + with model: |
| 281 | + ppc = pm.sample_posterior_predictive(trace, n_samples) |
| 282 | + prior = pm.sample_prior_predictive(samples=n_samples) |
| 283 | + assert ppc['x_obs'].shape == (n_samples,) + X.shape |
| 284 | + assert prior['x_obs'].shape == (n_samples, D) |
0 commit comments