Skip to content

Commit ab6e01d

Browse files
committed
refactored test suite to see prior and cond match separately. But had to use relative 5% error for DTC conditional
1 parent c3cdd2c commit ab6e01d

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

pymc3/gp/gp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ def _build_prior(self, name, X, Xu, **kwargs):
304304
Qff_diag = project_inverse(Kfu, Luu, diag=True)
305305
Kff_diag = self.cov_func.diag(X)
306306
# MvNormal with diagonal cov is Normal with sd=cov**0.5
307-
sd = tt.sqrt(tt.clip(Kff_diag - Qff_diag, 0, np.inf))
308-
f = pm.Normal(name, mu=f_, sd=sd, shape=shape)
307+
var = tt.clip(Kff_diag - Qff_diag, 0, np.inf)
308+
f = pm.Normal(name, mu=f_, tau=tt.inv(var), shape=shape)
309309
return f
310310

311311
def prior(self, name, X, Xu, **kwargs):

pymc3/tests/test_gp.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -698,21 +698,39 @@ def setup_method(self):
698698
mean_func = pm.gp.mean.Constant(0.5)
699699
gp = pm.gp.Latent(mean_func, cov_func)
700700
f = gp.prior("f", X)
701-
p = gp.conditional("p", Xnew)
702701
chol = np.linalg.cholesky(cov_func(X).eval())
703702
y_rotated = np.linalg.solve(chol, y - mean_func(X).eval())
704-
self.logp = model.logp({
705-
"p": pnew,
706-
'f_rotated_': y_rotated
707-
})
703+
logp_params = { 'f_rotated_': y_rotated }
704+
self.logp_prior = model.logp(logp_params)
705+
with model:
706+
p = gp.conditional("p", Xnew)
707+
logp_params['p'] = pnew
708+
self.logp_coditional = model.logp(logp_params)
708709
self.X = X
709710
self.Xnew = Xnew
710711
self.y = y
711712
self.pnew = pnew
712713
self.gp = gp
713714

714715
@pytest.mark.parametrize('approx', ['FITC', 'DTC'])
715-
def testApproximations(self, approx):
716+
def testPriorApproximations(self, approx):
717+
with pm.Model() as model:
718+
cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
719+
mean_func = pm.gp.mean.Constant(0.5)
720+
gp = pm.gp.LatentSparse(mean_func, cov_func, approx=approx)
721+
f = gp.prior("f", self.X, self.X)
722+
chol = np.linalg.cholesky(cov_func(self.X).eval())
723+
y_rotated = np.linalg.solve(chol, self.y - mean_func(self.X).eval())
724+
model_params = {
725+
"f_u_rotated_": y_rotated,
726+
}
727+
if approx == 'FITC':
728+
model_params['f'] = self.y # need to specify as well since f ~ Normal(f_, diag(Kff-Qff))
729+
approx_logp = model.logp(model_params)
730+
npt.assert_allclose(approx_logp, self.logp_prior, atol=0, rtol=1e-2)
731+
732+
@pytest.mark.parametrize('approx', ['FITC', 'DTC'])
733+
def testConditionalApproximations(self, approx):
716734
with pm.Model() as model:
717735
cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
718736
mean_func = pm.gp.mean.Constant(0.5)
@@ -728,7 +746,7 @@ def testApproximations(self, approx):
728746
if approx == 'FITC':
729747
model_params['f'] = self.y # need to specify as well since f ~ Normal(f_, diag(Kff-Qff))
730748
approx_logp = model.logp(model_params)
731-
npt.assert_allclose(approx_logp, self.logp, atol=0, rtol=1e-2)
749+
npt.assert_allclose(approx_logp, self.logp_coditional, atol=0, rtol=5e-2)
732750

733751

734752
class TestMarginalVsMarginalSparse(object):

0 commit comments

Comments
 (0)