Skip to content

Commit 768fa70

Browse files
committed
fix DTC conditional by clipping its covariance in [0,inf)
also use more descriptive names for vars
1 parent 1fa1ea7 commit 768fa70

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

pymc3/gp/gp.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,21 +290,22 @@ def __init__(self, mean_func=Zero(), cov_func=Constant(0.0), approx="FITC"):
290290

291291
def _build_prior(self, name, X, Xu, **kwargs):
292292
mu = self.mean_func(X) # (n,)
293-
L = cholesky(stabilize(self.cov_func(Xu))) # (m, m) \sqrt{K_u}
293+
Luu = cholesky(stabilize(self.cov_func(Xu))) # (m, m) \sqrt{K_u}
294294
shape = infer_shape(Xu, kwargs.pop("shape", None))
295295
v = pm.Normal(name + "_u_rotated_", mu=0.0, sd=1.0, shape=shape, **kwargs)
296-
u_ = self.mean_func(Xu) + tt.dot(L, v) # mean + chol method of MvGaussian
296+
u_ = self.mean_func(Xu) + tt.dot(Luu, v) # mean + chol method of MvGaussian
297297
u = pm.Deterministic(name+'_u', u_) # (m,) prior at inducing points
298-
Kuuiu = invert_dot(L, u) # (m,) K_{uu}^{-1} u
298+
Kuuiu = invert_dot(Luu, u) # (m,) K_{uu}^{-1} u
299299
Kfu = self.cov_func(X, Xu) # (n, m)
300300
f_ = mu + tt.dot(Kfu, Kuuiu) # (n, m) @ (m,) = (n,)
301301
if self.approx == 'DTC':
302302
f = pm.Deterministic("f", f_)
303303
elif self.approx == 'FITC':
304-
Qff_diag = project_inverse(Kfu, L, diag=True)
304+
Qff_diag = project_inverse(Kfu, Luu, diag=True)
305305
Kff_diag = self.cov_func.diag(X)
306306
# MvNormal with diagonal cov is Normal with sd=cov**0.5
307-
f = pm.Normal("f", mu=f_, sd=tt.sqrt(tt.clip(Kff_diag - Qff_diag, 0, np.inf)), shape=shape)
307+
sd = tt.sqrt(tt.clip(Kff_diag - Qff_diag, 0, np.inf))
308+
f = pm.Normal("f", mu=f_, sd=sd, shape=shape)
308309
return f
309310

310311
def prior(self, name, X, Xu, **kwargs):
@@ -393,7 +394,7 @@ def _build_conditional(self, Xnew, X, Xu, f, cov_total, mean_total):
393394
mus += tt.dot(Qsf, f / Lambda)
394395
Qss = project_inverse(Ksu, Luu)
395396
Kss = self.cov_func(Xnew)
396-
cov = Kss - Qss
397+
cov = tt.clip(Kss - Qss, 0, np.inf)
397398
if self.approx == 'FITC':
398399
cov -= tt.dot(Qsf, tt.transpose(Qsf / Lambda))
399400
return mus, cov

pymc3/tests/test_gp.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,10 @@ def setup_method(self):
701701
p = gp.conditional("p", Xnew)
702702
chol = np.linalg.cholesky(cov_func(X).eval())
703703
y_rotated = np.linalg.solve(chol, y - mean_func(X).eval())
704-
self.logp = model.logp({"p": pnew, 'f_rotated_': y_rotated})
704+
self.logp = model.logp({
705+
"p": pnew,
706+
'f_rotated_': y_rotated
707+
})
705708
self.X = X
706709
self.Xnew = Xnew
707710
self.y = y
@@ -718,7 +721,10 @@ def testApproximations(self, approx):
718721
p = gp.conditional("p", self.Xnew)
719722
chol = np.linalg.cholesky(cov_func(self.X).eval())
720723
y_rotated = np.linalg.solve(chol, self.y - mean_func(self.X).eval())
721-
model_params = {"f_u_rotated_": y_rotated, "p": self.pnew}
724+
model_params = {
725+
"f_u_rotated_": y_rotated,
726+
"p": self.pnew
727+
}
722728
if approx == 'FITC':
723729
model_params['f'] = self.y # need to specify as well since f ~ Normal(f_, diag(Kff-Qff))
724730
approx_logp = model.logp(model_params)

0 commit comments

Comments
 (0)