Skip to content

Commit 065c237

Browse files
committed
ENH Fix imports. Add verbosity arguments.
1 parent a48cc07 commit 065c237

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
import warnings
55

66
import numpy as np
7+
import scipy
8+
79
import theano.tensor as T
810
from scipy import stats
911
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
1012

1113
from . import transforms
1214
from .distribution import Continuous, Discrete, draw_values, generate_samples
15+
from ..model import Deterministic
16+
from .continuous import ChiSquared, Normal
1317
from .special import gammaln, multigammaln
1418
from .dist_math import bound, logpow, factln
1519

@@ -260,6 +264,7 @@ def logp(self, X):
260264
T.all(eigh(X)[0] > 0), T.eq(X, X.T),
261265
n > (p - 1))
262266

267+
263268
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
264269
"""
265270
Bartlett decomposition of the Wishart distribution. As the Wishart
@@ -303,18 +308,20 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
303308
tril_idx = np.tril_indices_from(S, k=-1)
304309
n_diag = len(diag_idx[0])
305310
n_tril = len(tril_idx[0])
306-
c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
307-
z = pm.Normal('z', 0, 1, shape=n_tril)
311+
c = T.sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
312+
print('Added new variable c to model diagonal of Wishart.')
313+
z = Normal('z', 0, 1, shape=n_tril)
314+
print('Added new variable z to model off-diagonals of Wishart.')
308315
# Construct A matrix
309316
A = T.zeros(S.shape, dtype=np.float32)
310317
A = T.set_subtensor(A[diag_idx], c)
311318
A = T.set_subtensor(A[tril_idx], z)
312319

313320
# L * A * A.T * L.T ~ Wishart(L*L.T, nu)
314321
if return_cholesky:
315-
return pm.Deterministic(name, T.dot(L, A))
322+
return Deterministic(name, T.dot(L, A))
316323
else:
317-
return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T))
324+
return Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T))
318325

319326

320327
class LKJCorr(Continuous):

0 commit comments

Comments
 (0)