Skip to content

Commit 24a9e54

Browse files
committed
Fix computation of S in approx_hsgp_hyperparams
1 parent 1ff3b8c commit 24a9e54

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232

3333

3434
def set_boundary(Xs: TensorLike, c: numbers.Real | TensorLike) -> np.ndarray:
35-
"""Set the boundary using the mean-subtracted `Xs` and `c`. `c` is usually a scalar
35+
"""Set the boundary using the `Xs` centered around 0 and `c`. `c` is usually a scalar
3636
multiplier greater than 1.0, but it may be one value per dimension or column of `Xs`.
3737
"""
38-
S = pt.max(pt.abs(Xs), axis=0)
38+
S = pt.max(pt.abs(Xs), axis=0) # important: the Xs should be centered around 0
3939
L = (c * S).eval() # eval() makes sure L is not changed with out-of-sample preds
4040
return L
4141

@@ -96,7 +96,7 @@ class HSGPParams(NamedTuple):
9696

9797

9898
def approx_hsgp_hyperparams(
99-
x_range: list[float], lengthscale_range: list[float], cov_func: str
99+
x: np.ndarray, lengthscale_range: list[float], cov_func: str
100100
) -> HSGPParams:
101101
"""Utility function that uses heuristics to recommend minimum `m` and `c` values,
102102
based on recommendations from Ruitort-Mayol et. al.
@@ -138,10 +138,12 @@ def approx_hsgp_hyperparams(
138138
- Ruitort-Mayol, G., Anderson, M., Solin, A., Vehtari, A. (2022).
139139
Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming
140140
"""
141-
if (x_range[0] >= x_range[1]) or (lengthscale_range[0] >= lengthscale_range[1]):
141+
if lengthscale_range[0] >= lengthscale_range[1]:
142142
raise ValueError("One of the boundaries out of order")
143143

144-
S = (x_range[1] - x_range[0]) / 2
144+
X_center = (np.max(x, axis=0) - np.min(x, axis=0)) / 2
145+
Xs = x - X_center
146+
S = np.max(np.abs(Xs), axis=0)
145147

146148
if cov_func.lower() == "expquad":
147149
a1, a2 = 3.2, 1.75
@@ -401,7 +403,7 @@ def prior_linearized(self, Xs: TensorLike):
401403
# If not provided, use Xs and c to set L
402404
if self._L is None:
403405
assert isinstance(self._c, numbers.Real | np.ndarray | pt.TensorVariable)
404-
self.L = pt.as_tensor(set_boundary(Xs, self._c))
406+
self.L = pt.as_tensor(set_boundary(Xs, self._c)) # Xs should be 0-centered
405407
else:
406408
self.L = self._L
407409

0 commit comments

Comments
 (0)