Skip to content

Commit 5281d33

Browse files
bwengalsAlexAndorra
authored andcommitted
remove c, HSGPParams tuple, fix docstring
1 parent 3ae419b commit 5281d33

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from collections.abc import Sequence
1919
from types import ModuleType
20-
from typing import NamedTuple
2120

2221
import numpy as np
2322
import pytensor.tensor as pt
@@ -89,15 +88,9 @@ def calc_basis_periodic(
8988
return phi_cos, phi_sin
9089

9190

92-
class HSGPParams(NamedTuple):
93-
m: int
94-
c: float
95-
S: float
96-
97-
9891
def approx_hsgp_hyperparams(
9992
x_range: list[float], lengthscale_range: list[float], cov_func: str
100-
) -> HSGPParams:
93+
) -> tuple[int, float]:
10194
"""Utility function that uses heuristics to recommend minimum `m` and `c` values,
10295
based on recommendations from Ruitort-Mayol et. al.
10396
@@ -107,10 +100,10 @@ def approx_hsgp_hyperparams(
107100
that 95% of the prior mass of the lengthscale is between 1 and 5, set the
108101
`lengthscale_range` to be [1, 5], or maybe a touch wider.
109102
110-
Also, be sure to pass in an `x` that is exemplary of the domain not just of your
103+
Also, be sure to pass in an `x_range` that is exemplary of the domain not just of your
111104
training data, but also where you intend to make predictions. For instance, if your
112-
training x values are from [0, 10], and you intend to predict from [7, 15], you can
113-
pass in `x_range = [0, 15]`.
105+
training x values are from [0, 10], and you intend to predict from [7, 15], the narrowest
106+
`x_range` you should pass in would be `x_range = [0, 15]`.
114107
115108
NB: These recommendations are based on a one-dimensional GP.
116109
@@ -126,15 +119,11 @@ def approx_hsgp_hyperparams(
126119
127120
Returns
128121
-------
129-
HSGPParams
130-
A named tuple containing the recommended values for `m`, `c`, and `S`.
131-
- `m` : int
132-
Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
133-
- `c` : float
134-
Scaling factor such that L = c * S, where L is the boundary of the approximation.
135-
Increasing it helps approximate larger lengthscales, but may require increasing m.
136-
- `S` : float
137-
The value of `S`, which is half the range, or radius, of `x`.
122+
- `m` : int
123+
Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
124+
- `c` : float
125+
Scaling factor such that L = c * S, where L is the boundary of the approximation.
126+
Increasing it helps approximate larger lengthscales, but may require increasing m.
138127
139128
Raises
140129
------
@@ -171,7 +160,7 @@ def approx_hsgp_hyperparams(
171160
c = max(a1 * (lengthscale_range[1] / S), 1.2)
172161
m = int(a2 * c / (lengthscale_range[0] / S))
173162

174-
return HSGPParams(m=m, c=c, S=S)
163+
return m, c
175164

176165

177166
class HSGP(Base):

0 commit comments

Comments
 (0)