Skip to content

Commit 1ff3b8c

Browse files
committed
Use midpoint of X instead of mean
1 parent 3da91cc commit 1ff3b8c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(
302302
self._L = pt.as_tensor(L).eval() # make sure L cannot be changed
303303
self._c = c
304304
self._parametrization = parametrization
305-
self._X_mean = None
305+
self._X_center = None
306306

307307
super().__init__(mean_func=mean_func, cov_func=cov_func)
308308

@@ -389,11 +389,11 @@ def prior_linearized(self, Xs: TensorLike):
389389
with model:
390390
ppc = pm.sample_posterior_predictive(idata, var_names=["f"])
391391
"""
392-
# Important: fix the computation of the mean. If X is mutated later,
393-
# the training mean will be subtracted, not the testing mean.
394-
if self._X_mean is None:
395-
self._X_mean = pt.mean(Xs, axis=0).eval()
396-
Xs = Xs - self._X_mean # mean-center for accurate computation
392+
# Important: fix the computation of the midpoint of X.
393+
# If X is mutated later, the training midpoint will be subtracted, not the testing one.
394+
if self._X_center is None:
395+
self._X_center = (pt.max(Xs, axis=0) - pt.min(Xs, axis=0)).eval() / 2
396+
Xs = Xs - self._X_center # center for accurate computation
397397

398398
# Index Xs using input_dim and active_dims of covariance function
399399
Xs, _ = self.cov_func._slice(Xs)
@@ -457,7 +457,7 @@ def prior(
457457

458458
def _build_conditional(self, Xnew):
459459
try:
460-
beta, X_mean = self._beta, self._X_mean
460+
beta, X_center = self._beta, self._X_center
461461

462462
if self._parametrization == "noncentered":
463463
sqrt_psd = self._sqrt_psd
@@ -470,7 +470,7 @@ def _build_conditional(self, Xnew):
470470
Xnew, _ = self.cov_func._slice(Xnew)
471471

472472
eigvals = calc_eigenvalues(self.L, self._m)
473-
phi = calc_eigenvectors(Xnew - X_mean, self.L, eigvals, self._m)
473+
phi = calc_eigenvectors(Xnew - X_center, self.L, eigvals, self._m)
474474
i = int(self._drop_first is True)
475475

476476
if self._parametrization == "noncentered":

0 commit comments

Comments
 (0)