Skip to content

Commit 642f639

Browse files
committed
BUG Fix bounds of Wishart and add warning.
The Wishart has never worked correctly and produced divergent samples. One of the reasons was that we did not enforce the matrix to be symmetric nor positive semi-definite. Now, however, it is impossible to randomly sample a valid matrix so we raise a warning that the Wishart is basically unusable currently and to instead use the LKJ prior which has many desriable properties. For more discussions, see #538.
1 parent a962aeb commit 642f639

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import warnings
2+
13
from .dist_math import *
24

35
import numpy as np
46

5-
from theano.tensor.nlinalg import det, matrix_inverse, trace
6-
from theano.tensor import dot, cast, eye, diag, eq, le, ge, all
7+
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
8+
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all
79
from theano.printing import Print
810

911
__all__ = ['MvNormal', 'Dirichlet', 'Multinomial', 'Wishart', 'LKJCorr']
@@ -160,6 +162,7 @@ class Wishart(Continuous):
160162
"""
161163
def __init__(self, n, V, *args, **kwargs):
162164
super(Wishart, self).__init__(*args, **kwargs)
165+
warnings.warn('The Wishart distribution can currently not be used for MCMC sampling. The probability of sampling a symmetric matrix is basically zero. Instead, please use the LKJCorr prior. For more information on the issues surrounding the Wishart see here: https://github.com/pymc-devs/pymc3/issues/538.', UserWarning)
163166
self.n = n
164167
self.p = p = V.shape[0]
165168
self.V = V
@@ -179,7 +182,10 @@ def logp(self, X):
179182
return bound(
180183
((n - p - 1) * log(IXI) - trace(matrix_inverse(V).dot(X)) -
181184
n * p * log(2) - n * log(IVI) - 2 * multigammaln(n / 2., p)) / 2,
182-
n > (p - 1))
185+
gt(n, (p - 1)),
186+
all(gt(eigh(X)[0], 0)),
187+
eq(X, X.T)
188+
)
183189

184190

185191
class LKJCorr(Continuous):
@@ -249,7 +255,7 @@ def logp(self, x):
249255

250256
X = x[self.tri_index]
251257
X = t.fill_diagonal(X, 1)
252-
258+
253259
result = self._normalizing_constant(n, p)
254260
result += (n - 1.) * log(det(X))
255261
return bound(result,

0 commit comments

Comments
 (0)