17
17
18
18
from collections .abc import Sequence
19
19
from types import ModuleType
20
- from typing import NamedTuple
21
20
22
21
import numpy as np
23
22
import pytensor .tensor as pt
@@ -89,15 +88,9 @@ def calc_basis_periodic(
89
88
return phi_cos , phi_sin
90
89
91
90
92
- class HSGPParams (NamedTuple ):
93
- m : int
94
- c : float
95
- S : float
96
-
97
-
98
91
def approx_hsgp_hyperparams (
99
92
x_range : list [float ], lengthscale_range : list [float ], cov_func : str
100
- ) -> HSGPParams :
93
+ ) -> tuple [ int , float ] :
101
94
"""Utility function that uses heuristics to recommend minimum `m` and `c` values,
102
95
based on recommendations from Ruitort-Mayol et. al.
103
96
@@ -107,10 +100,10 @@ def approx_hsgp_hyperparams(
107
100
that 95% of the prior mass of the lengthscale is between 1 and 5, set the
108
101
`lengthscale_range` to be [1, 5], or maybe a touch wider.
109
102
110
- Also, be sure to pass in an `x ` that is exemplary of the domain not just of your
103
+ Also, be sure to pass in an `x_range ` that is exemplary of the domain not just of your
111
104
training data, but also where you intend to make predictions. For instance, if your
112
- training x values are from [0, 10], and you intend to predict from [7, 15], you can
113
- pass in `x_range = [0, 15]`.
105
+ training x values are from [0, 10], and you intend to predict from [7, 15], the narrowest
106
+ `x_range` you should pass in would be `x_range = [0, 15]`.
114
107
115
108
NB: These recommendations are based on a one-dimensional GP.
116
109
@@ -126,15 +119,11 @@ def approx_hsgp_hyperparams(
126
119
127
120
Returns
128
121
-------
129
- HSGPParams
130
- A named tuple containing the recommended values for `m`, `c`, and `S`.
131
- - `m` : int
132
- Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
133
- - `c` : float
134
- Scaling factor such that L = c * S, where L is the boundary of the approximation.
135
- Increasing it helps approximate larger lengthscales, but may require increasing m.
136
- - `S` : float
137
- The value of `S`, which is half the range, or radius, of `x`.
122
+ - `m` : int
123
+ Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
124
+ - `c` : float
125
+ Scaling factor such that L = c * S, where L is the boundary of the approximation.
126
+ Increasing it helps approximate larger lengthscales, but may require increasing m.
138
127
139
128
Raises
140
129
------
@@ -171,7 +160,7 @@ def approx_hsgp_hyperparams(
171
160
c = max (a1 * (lengthscale_range [1 ] / S ), 1.2 )
172
161
m = int (a2 * c / (lengthscale_range [0 ] / S ))
173
162
174
- return HSGPParams ( m = m , c = c , S = S )
163
+ return m , c
175
164
176
165
177
166
class HSGP (Base ):
0 commit comments