Skip to content

Commit 3ae419b

Browse files
bwengalsAlexAndorra
authored andcommitted
fix test fails
1 parent 7bb5d05 commit 3ae419b

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,14 @@ def approx_hsgp_hyperparams(
102102
based on recommendations from Ruitort-Mayol et. al.
103103
104104
In practice, you need to choose `c` large enough to handle the largest lengthscales,
105-
and `m` large enough to accommodate the smallest lengthscales. Use your prior on the
106-
lengthscale as guidance for setting the prior range. For example, if you believe
105+
and `m` large enough to accommodate the smallest lengthscales. Use your prior on the
106+
lengthscale as guidance for setting the prior range. For example, if you believe
107107
that 95% of the prior mass of the lengthscale is between 1 and 5, set the
108108
`lengthscale_range` to be [1, 5], or maybe a touch wider.
109109
110-
Also, be sure to pass in an `x` that is exemplary of the domain not just of your
110+
Also, be sure to pass in an `x` that is exemplary of the domain not just of your
111111
training data, but also where you intend to make predictions. For instance, if your
112-
training x values are from [0, 10], and you intend to predict from [7, 15], you can
112+
training x values are from [0, 10], and you intend to predict from [7, 15], you can
113113
pass in `x_range = [0, 15]`.
114114
115115
NB: These recommendations are based on a one-dimensional GP.
@@ -295,6 +295,7 @@ def __init__(
295295

296296
if parametrization is not None:
297297
parametrization = parametrization.lower().replace("-", "").replace("_", "")
298+
298299
if parametrization not in ["centered", "noncentered"]:
299300
raise ValueError("`parametrization` must be either 'centered' or 'noncentered'.")
300301

@@ -597,6 +598,7 @@ def __init__(
597598

598599
self._m = m
599600
self.scale = scale
601+
self._X_center = None
600602

601603
super().__init__(mean_func=mean_func, cov_func=cov_func)
602604

@@ -672,8 +674,8 @@ def prior_linearized(self, X: TensorLike):
672674
# Important: fix the computation of the midpoint of X.
673675
# If X is mutated later, the training midpoint will be subtracted, not the testing one.
674676
if self._X_center is None:
675-
self._X_center = (pt.max(Xs, axis=0) + pt.min(Xs, axis=0)).eval() / 2
676-
Xs = Xs - self._X_center # center for accurate computation
677+
self._X_center = (pt.max(X, axis=0) + pt.min(X, axis=0)).eval() / 2
678+
Xs = X - self._X_center # center for accurate computation
677679

678680
# Index Xs using input_dim and active_dims of covariance function
679681
Xs, _ = self.cov_func._slice(Xs)
@@ -715,7 +717,7 @@ def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ign
715717

716718
def _build_conditional(self, Xnew):
717719
try:
718-
beta, X_mean = self._beta, self._X_mean
720+
beta, X_center = self._beta, self._X_center
719721

720722
except AttributeError:
721723
raise ValueError(
@@ -724,7 +726,9 @@ def _build_conditional(self, Xnew):
724726

725727
Xnew, _ = self.cov_func._slice(Xnew)
726728

727-
phi_cos, phi_sin = calc_basis_periodic(Xnew - X_mean, self.cov_func.period, self._m, tl=pt)
729+
phi_cos, phi_sin = calc_basis_periodic(
730+
Xnew - X_center, self.cov_func.period, self._m, tl=pt
731+
)
728732
m = self._m
729733
J = pt.arange(0, m, 1)
730734
# rescale basis coefficients by the sqrt variance term

tests/gp/test_hsgp_approx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_mean_invariance(self):
129129
_ = pm.Data("X", X)
130130
cov_func = pm.gp.cov.ExpQuad(1, ls=3)
131131
gp = pm.gp.HSGP(m=[20], L=[10], cov_func=cov_func)
132-
_ = gp.prior_linearized(Xs=X)
132+
_ = gp.prior_linearized(X=X)
133133

134134
x_new = np.linspace(-10, 20, 100)[:, None]
135135
with model:

0 commit comments

Comments
 (0)