|
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | import numpy as np
|
| 7 | +import scipy |
| 8 | + |
7 | 9 | import theano.tensor as T
|
8 | 10 | from scipy import stats
|
9 | 11 | from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
|
10 | 12 |
|
11 | 13 | from . import transforms
|
12 | 14 | from .distribution import Continuous, Discrete, draw_values, generate_samples
|
| 15 | +from ..model import Deterministic |
| 16 | +from .continuous import ChiSquared, Normal |
13 | 17 | from .special import gammaln, multigammaln
|
14 | 18 | from .dist_math import bound, logpow, factln
|
15 | 19 |
|
@@ -260,6 +264,7 @@ def logp(self, X):
|
260 | 264 | T.all(eigh(X)[0] > 0), T.eq(X, X.T),
|
261 | 265 | n > (p - 1))
|
262 | 266 |
|
| 267 | + |
263 | 268 | def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
|
264 | 269 | """
|
265 | 270 | Bartlett decomposition of the Wishart distribution. As the Wishart
|
@@ -303,18 +308,20 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
|
303 | 308 | tril_idx = np.tril_indices_from(S, k=-1)
|
304 | 309 | n_diag = len(diag_idx[0])
|
305 | 310 | n_tril = len(tril_idx[0])
|
306 |
| - c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
307 |
| - z = pm.Normal('z', 0, 1, shape=n_tril) |
| 311 | + c = T.sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
| 312 | + print('Added new variable c to model diagonal of Wishart.') |
| 313 | + z = Normal('z', 0, 1, shape=n_tril) |
| 314 | + print('Added new variable z to model off-diagonals of Wishart.') |
308 | 315 | # Construct A matrix
|
309 | 316 | A = T.zeros(S.shape, dtype=np.float32)
|
310 | 317 | A = T.set_subtensor(A[diag_idx], c)
|
311 | 318 | A = T.set_subtensor(A[tril_idx], z)
|
312 | 319 |
|
313 | 320 | # L * A * A.T * L.T ~ Wishart(L*L.T, nu)
|
314 | 321 | if return_cholesky:
|
315 |
| - return pm.Deterministic(name, T.dot(L, A)) |
| 322 | + return Deterministic(name, T.dot(L, A)) |
316 | 323 | else:
|
317 |
| - return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T)) |
| 324 | + return Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T)) |
318 | 325 |
|
319 | 326 |
|
320 | 327 | class LKJCorr(Continuous):
|
|
0 commit comments