|
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | from .dist_math import *
|
| 7 | +from . import ChiSquared, Normal |
| 8 | +from .. import Deterministic |
7 | 9 |
|
8 | 10 | import numpy as np
|
| 11 | +import scipy |
9 | 12 | from .transforms import simplextransform
|
10 | 13 |
|
11 | 14 | from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
|
12 |
| -from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all |
| 15 | +from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor |
13 | 16 | from theano.printing import Print
|
14 | 17 |
|
15 | 18 | __all__ = ['MvNormal', 'Dirichlet', 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr']
|
@@ -234,18 +237,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
|
234 | 237 | tril_idx = np.tril_indices_from(S, k=-1)
|
235 | 238 | n_diag = len(diag_idx[0])
|
236 | 239 | n_tril = len(tril_idx[0])
|
237 |
| - c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
238 |
| - z = pm.Normal('z', 0, 1, shape=n_tril) |
| 240 | + c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
| 241 | + z = Normal('z', 0, 1, shape=n_tril) |
239 | 242 | # Construct A matrix
|
240 |
| - A = T.zeros(S.shape, dtype=np.float32) |
241 |
| - A = T.set_subtensor(A[diag_idx], c) |
242 |
| - A = T.set_subtensor(A[tril_idx], z) |
| 243 | + A = zeros(S.shape, dtype=np.float32) |
| 244 | + A = set_subtensor(A[diag_idx], c) |
| 245 | + A = set_subtensor(A[tril_idx], z) |
243 | 246 |
|
244 | 247 | # L * A * A.T * L.T ~ Wishart(L*L.T, nu)
|
245 | 248 | if return_cholesky:
|
246 |
| - return pm.Deterministic(name, T.dot(L, A)) |
| 249 | + return Deterministic(name, dot(L, A)) |
247 | 250 | else:
|
248 |
| - return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T)) |
| 251 | + return Deterministic(name, dot(dot(dot(L, A), A.T), L.T)) |
249 | 252 |
|
250 | 253 |
|
251 | 254 | class LKJCorr(Continuous):
|
|
0 commit comments