@@ -290,21 +290,22 @@ def __init__(self, mean_func=Zero(), cov_func=Constant(0.0), approx="FITC"):
290
290
291
291
def _build_prior (self , name , X , Xu , ** kwargs ):
292
292
mu = self .mean_func (X ) # (n,)
293
- L = cholesky (stabilize (self .cov_func (Xu ))) # (m, m) \sqrt{K_u}
293
+ Luu = cholesky (stabilize (self .cov_func (Xu ))) # (m, m) \sqrt{K_u}
294
294
shape = infer_shape (Xu , kwargs .pop ("shape" , None ))
295
295
v = pm .Normal (name + "_u_rotated_" , mu = 0.0 , sd = 1.0 , shape = shape , ** kwargs )
296
- u_ = self .mean_func (Xu ) + tt .dot (L , v ) # mean + chol method of MvGaussian
296
+ u_ = self .mean_func (Xu ) + tt .dot (Luu , v ) # mean + chol method of MvGaussian
297
297
u = pm .Deterministic (name + '_u' , u_ ) # (m,) prior at inducing points
298
- Kuuiu = invert_dot (L , u ) # (m,) K_{uu}^{-1} u
298
+ Kuuiu = invert_dot (Luu , u ) # (m,) K_{uu}^{-1} u
299
299
Kfu = self .cov_func (X , Xu ) # (n, m)
300
300
f_ = mu + tt .dot (Kfu , Kuuiu ) # (n, m) @ (m,) = (n,)
301
301
if self .approx == 'DTC' :
302
302
f = pm .Deterministic ("f" , f_ )
303
303
elif self .approx == 'FITC' :
304
- Qff_diag = project_inverse (Kfu , L , diag = True )
304
+ Qff_diag = project_inverse (Kfu , Luu , diag = True )
305
305
Kff_diag = self .cov_func .diag (X )
306
306
# MvNormal with diagonal cov is Normal with sd=cov**0.5
307
- f = pm .Normal ("f" , mu = f_ , sd = tt .sqrt (tt .clip (Kff_diag - Qff_diag , 0 , np .inf )), shape = shape )
307
+ sd = tt .sqrt (tt .clip (Kff_diag - Qff_diag , 0 , np .inf ))
308
+ f = pm .Normal ("f" , mu = f_ , sd = sd , shape = shape )
308
309
return f
309
310
310
311
def prior (self , name , X , Xu , ** kwargs ):
@@ -393,7 +394,7 @@ def _build_conditional(self, Xnew, X, Xu, f, cov_total, mean_total):
393
394
mus += tt .dot (Qsf , f / Lambda )
394
395
Qss = project_inverse (Ksu , Luu )
395
396
Kss = self .cov_func (Xnew )
396
- cov = Kss - Qss
397
+ cov = tt . clip ( Kss - Qss , 0 , np . inf )
397
398
if self .approx == 'FITC' :
398
399
cov -= tt .dot (Qsf , tt .transpose (Qsf / Lambda ))
399
400
return mus , cov
0 commit comments