Skip to content

Commit b4f105f

Browse files
committed
Add WishartBartlett example.
1 parent 91b4359 commit b4f105f

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import warnings
55

66
from .dist_math import *
7+
from . import ChiSquared, Normal
8+
from .. import Deterministic
79

810
import numpy as np
11+
import scipy
912
from .transforms import simplextransform
1013

1114
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
12-
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all
15+
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor
1316
from theano.printing import Print
1417

1518
__all__ = ['MvNormal', 'Dirichlet', 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr']
@@ -234,18 +237,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
234237
tril_idx = np.tril_indices_from(S, k=-1)
235238
n_diag = len(diag_idx[0])
236239
n_tril = len(tril_idx[0])
237-
c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
238-
z = pm.Normal('z', 0, 1, shape=n_tril)
240+
c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
241+
z = Normal('z', 0, 1, shape=n_tril)
239242
# Construct A matrix
240-
A = T.zeros(S.shape, dtype=np.float32)
241-
A = T.set_subtensor(A[diag_idx], c)
242-
A = T.set_subtensor(A[tril_idx], z)
243+
A = zeros(S.shape, dtype=np.float32)
244+
A = set_subtensor(A[diag_idx], c)
245+
A = set_subtensor(A[tril_idx], z)
243246

244247
# L * A * A.T * L.T ~ Wishart(L*L.T, nu)
245248
if return_cholesky:
246-
return pm.Deterministic(name, T.dot(L, A))
249+
return Deterministic(name, dot(L, A))
247250
else:
248-
return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T))
251+
return Deterministic(name, dot(dot(dot(L, A), A.T), L.T))
249252

250253

251254
class LKJCorr(Continuous):

pymc3/examples/wishart.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pymc3 as pm
2+
import numpy as np
3+
import theano
4+
import theano.tensor as T
5+
import scipy.stats
6+
import matplotlib.pyplot as plt
7+
8+
# Covariance matrix we want to recover
9+
covariance = np.matrix([[2, .5, -.5],
10+
[.5, 1., 0.],
11+
[-.5, 0., 0.5]])
12+
13+
prec = np.linalg.inv(covariance)
14+
15+
mean = [.5, 1, .2]
16+
data = scipy.stats.multivariate_normal(mean, covariance).rvs(5000)
17+
18+
plt.scatter(data[:, 0], data[:, 1])
19+
20+
with pm.Model() as model:
21+
S = np.eye(3)
22+
nu = 5
23+
mu = pm.Normal('mu', mu=0, sd=1, shape=3)
24+
25+
# Use the transformed Wishart distribution
26+
# Under the hood this will do a Cholesky decomposition
27+
# of S and add two RVs to the sampler: c and z
28+
prec = pm.WishartBartlett('prec', S, nu)
29+
30+
# To be able to compare it to truth, convert precision to covariance
31+
cov = pm.Deterministic('cov', T.nlinalg.matrix_inverse(prec))
32+
33+
lp = pm.MvNormal('likelihood', mu=mu, tau=prec, observed=data)
34+
35+
start = pm.find_MAP()
36+
step = pm.NUTS(scaling=start)
37+
trace = pm.sample(500, step)
38+
39+
pm.traceplot(trace[100:]);

0 commit comments

Comments
 (0)