File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -107,8 +107,6 @@ def _quaddist(self, value):
107
107
108
108
def _quaddist_chol (self , delta ):
109
109
chol_cov = self .chol_cov
110
- _ , k = delta .shape
111
- k = pm .floatX (k )
112
110
diag = tt .nlinalg .diag (chol_cov )
113
111
# Check if the covariance matrix is positive definite.
114
112
ok = tt .all (diag > 0 )
@@ -126,14 +124,13 @@ def _quaddist_cov(self, delta):
126
124
127
125
def _quaddist_tau (self , delta ):
128
126
chol_tau = self .chol_tau
129
- _ , k = delta .shape
130
- k = pm .floatX (k )
131
-
132
127
diag = tt .nlinalg .diag (chol_tau )
128
+ # Check if the precision matrix is positive definite.
133
129
ok = tt .all (diag > 0 )
134
-
130
+ # If not, replace the diagonal. We return -inf later, but
131
+ # need to prevent solve_lower from throwing an exception.
135
132
chol_tau = tt .switch (ok , chol_tau , 1 )
136
- diag = tt . nlinalg . diag ( chol_tau )
133
+
137
134
delta_trans = tt .dot (delta , chol_tau )
138
135
quaddist = (delta_trans ** 2 ).sum (axis = - 1 )
139
136
logdet = - tt .sum (tt .log (diag ))
You can’t perform that action at this time.
0 commit comments