Skip to content

Commit 426b5c6

Browse files
committed
fix comments, make pass jitter through correctly, get rid of is_observed arg
1 parent 5740f2e commit 426b5c6

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

pymc/gp/gp.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def __add__(self, other):
691691
new_gp.approx = self.approx
692692
return new_gp
693693

694-
def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
694+
def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
695695
sigma2 = at.square(sigma)
696696
Kuu = self.cov_func(Xu)
697697
Kuf = self.cov_func(Xu, X)
@@ -720,9 +720,7 @@ def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
720720
quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
721721
return -1.0 * (constant + logdet + quadratic + trace)
722722

723-
def marginal_likelihood(
724-
self, name, X, Xu, y, noise=None, is_observed=True, jitter=JITTER_DEFAULT, **kwargs
725-
):
723+
def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
726724
R"""
727725
Returns the approximate marginal likelihood distribution, given the input
728726
locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -759,8 +757,8 @@ def marginal_likelihood(
759757
else:
760758
self.sigma = noise
761759

762-
approx_logp = self._build_marginal_likelihood_logp(y, X, Xu, noise, JITTER_DEFAULT)
763-
pm.Potential(f"marginalapprox_logp_{name}", approx_logp)
760+
approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
761+
pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)
764762

765763
def _build_conditional(
766764
self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter

pymc/tests/test_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ class TestMarginalVsMarginalApprox:
852852
def setup_method(self):
853853
self.sigma = 0.1
854854
self.x = np.linspace(-5, 5, 30)
855-
self.y = 0.25 * self.x + self.sigma * np.random.randn(len(self.x))
855+
self.y = np.random.normal(0.25 * self.x, self.sigma)
856856
with pm.Model() as model:
857857
cov_func = pm.gp.cov.Linear(1, c=0.0)
858858
c = pm.Normal("c", mu=20.0, sigma=100.0) # far from true value

0 commit comments

Comments
 (0)