Skip to content

Commit 18c78f4

Browse files
authored
Remove redundant line in _QuadFormBase (#3630)
* Remove redundant line in `_QuadFormBase` * Fix accidental typo
1 parent 88c9434 commit 18c78f4

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

pymc3/distributions/multivariate.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ def _quaddist(self, value):
107107

108108
def _quaddist_chol(self, delta):
109109
chol_cov = self.chol_cov
110-
_, k = delta.shape
111-
k = pm.floatX(k)
112110
diag = tt.nlinalg.diag(chol_cov)
113111
# Check if the covariance matrix is positive definite.
114112
ok = tt.all(diag > 0)
@@ -126,14 +124,13 @@ def _quaddist_cov(self, delta):
126124

127125
def _quaddist_tau(self, delta):
128126
chol_tau = self.chol_tau
129-
_, k = delta.shape
130-
k = pm.floatX(k)
131-
132127
diag = tt.nlinalg.diag(chol_tau)
128+
# Check if the precision matrix is positive definite.
133129
ok = tt.all(diag > 0)
134-
130+
# If not, replace the diagonal. We return -inf later, but
131+
# need to prevent solve_lower from throwing an exception.
135132
chol_tau = tt.switch(ok, chol_tau, 1)
136-
diag = tt.nlinalg.diag(chol_tau)
133+
137134
delta_trans = tt.dot(delta, chol_tau)
138135
quaddist = (delta_trans ** 2).sum(axis=-1)
139136
logdet = -tt.sum(tt.log(diag))

0 commit comments

Comments
 (0)