
Commit e07eea7

Gp sigma parameterize 6094 (#6145)
* Standardizing 'noise' parameter to be called 'sigma'
* Reformatting with black / pre-commit
* Removing 'noise' param in tests / removing related FutureWarnings from tests
* Looking for the warnings and the ValueError
* Capitalize acronyms and use "noise_func" instead of "sigma"

Co-authored-by: Will Dean <[email protected]>
1 parent c5ae227 commit e07eea7
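
In user-facing terms, the change is a keyword rename. A minimal sketch (the model is condensed from the `Marginal` docstring example in the diff below; the data and lengthscale are illustrative, not from the commit):

```python
import numpy as np
import pymc as pm

X = np.linspace(0, 1, 50)[:, None]
y = np.random.randn(50)

with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.1))
    sigma = pm.HalfCauchy("sigma", beta=3)

    # New, standardized keyword:
    y_ = gp.marginal_likelihood("y", X=X, y=y, sigma=sigma)

    # Old keyword still works for now but emits a FutureWarning:
    # y_ = gp.marginal_likelihood("y", X=X, y=y, noise=sigma)
```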

2 files changed (+139, -38 lines)

pymc/gp/gp.py

Lines changed: 59 additions & 30 deletions
```diff
@@ -37,6 +37,26 @@
 __all__ = ["Latent", "Marginal", "TP", "MarginalApprox", "LatentKron", "MarginalKron"]
 
 
+_noise_deprecation_warning = (
+    "The 'noise' parameter has been changed to 'sigma' "
+    "in order to standardize the GP API and will be "
+    "deprecated in future releases."
+)
+
+
+def _handle_sigma_noise_parameters(sigma, noise):
+    """Helper function for transition of 'noise' parameter to be named 'sigma'."""
+
+    if (sigma is None and noise is None) or (sigma is not None and noise is not None):
+        raise ValueError("'sigma' argument must be specified.")
+
+    if sigma is None:
+        warnings.warn(_noise_deprecation_warning, FutureWarning)
+        return noise
+
+    return sigma
+
+
 class Base:
     R"""
     Base class.
```
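
The helper's three paths, spelled out (a sketch; `_handle_sigma_noise_parameters` is a private module-level function, so importing it directly is for illustration only):

```python
import warnings

from pymc.gp.gp import _handle_sigma_noise_parameters

# New keyword passes straight through.
assert _handle_sigma_noise_parameters(sigma=1.0, noise=None) == 1.0

# Legacy keyword is forwarded, with a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert _handle_sigma_noise_parameters(sigma=None, noise=2.0) == 2.0
    assert issubclass(caught[-1].category, FutureWarning)

# Neither, or both, raises ValueError:
# _handle_sigma_noise_parameters(sigma=None, noise=None)
# _handle_sigma_noise_parameters(sigma=1.0, noise=1.0)
```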
```diff
@@ -218,7 +238,7 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
         Xnew: array-like
             Function input values.
         given: dict
-            Can optionally take as key value pairs: `X`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `y`,
             and `gp`. See the section in the documentation on additive GP
             models in PyMC for more information.
         jitter: scalar
```
```diff
@@ -359,7 +379,7 @@ def conditional(self, name, Xnew, jitter=JITTER_DEFAULT, **kwargs):
         return pm.MvStudentT(name, nu=nu2, mu=mu, cov=cov, **kwargs)
 
 
-@conditioned_vars(["X", "y", "noise"])
+@conditioned_vars(["X", "y", "sigma"])
 class Marginal(Base):
     R"""
     Marginal Gaussian process.
```
```diff
@@ -393,7 +413,7 @@ class Marginal(Base):
 
         # Place a GP prior over the function f.
         sigma = pm.HalfCauchy("sigma", beta=3)
-        y_ = gp.marginal_likelihood("y", X=X, y=y, noise=sigma)
+        y_ = gp.marginal_likelihood("y", X=X, y=y, sigma=sigma)
 
     ...
```
```diff
@@ -405,15 +425,15 @@ class Marginal(Base):
         fcond = gp.conditional("fcond", Xnew=Xnew)
     """
 
-    def _build_marginal_likelihood(self, X, noise, jitter):
+    def _build_marginal_likelihood(self, X, noise_func, jitter):
         mu = self.mean_func(X)
         Kxx = self.cov_func(X)
-        Knx = noise(X)
+        Knx = noise_func(X)
         cov = Kxx + Knx
         return mu, stabilize(cov, jitter)
 
     def marginal_likelihood(
-        self, name, X, y, noise, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
+        self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
     ):
         R"""
         Returns the marginal likelihood distribution, given the input
```
```diff
@@ -435,23 +455,25 @@ def marginal_likelihood(
         y: array-like
             Data that is the sum of the function with the GP prior and Gaussian
             noise. Must have shape `(n, )`.
-        noise: scalar, Variable, or Covariance
+        sigma: scalar, Variable, or Covariance
             Standard deviation of the Gaussian noise. Can also be a Covariance for
             non-white noise.
+        noise: scalar, Variable, or Covariance
+            Previous parameterization of `sigma`.
         jitter: scalar
             A small correction added to the diagonal of positive semi-definite
             covariance matrices to ensure numerical stability.
         **kwargs
             Extra keyword arguments that are passed to `MvNormal` distribution
             constructor.
         """
+        sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)
 
-        if not isinstance(noise, Covariance):
-            noise = pm.gp.cov.WhiteNoise(noise)
-        mu, cov = self._build_marginal_likelihood(X, noise, jitter)
+        noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
+        mu, cov = self._build_marginal_likelihood(X=X, noise_func=noise_func, jitter=jitter)
         self.X = X
         self.y = y
-        self.noise = noise
+        self.sigma = noise_func
         if is_observed:
             return pm.MvNormal(name, mu=mu, cov=cov, observed=y, **kwargs)
         else:
```
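
As the updated docstring says, `sigma` may also be a `Covariance` for non-white noise; anything else is wrapped in `pm.gp.cov.WhiteNoise` before `noise_func` is built. A hedged sketch, reusing `X` and `y` from the first example (the kernel choices are illustrative):

```python
with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.1))

    # Scalar or random variable: treated as the white-noise standard deviation.
    y_white = gp.marginal_likelihood("y_white", X=X, y=y, sigma=0.1)

    # Covariance object: used directly as the noise covariance.
    noise_cov = pm.gp.cov.WhiteNoise(0.1) + pm.gp.cov.ExpQuad(1, ls=0.01)
    y_corr = gp.marginal_likelihood("y_corr", X=X, y=y, sigma=noise_cov)
```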
```diff
@@ -472,20 +494,24 @@ def _get_given_vals(self, given):
         else:
             cov_total = self.cov_func
             mean_total = self.mean_func
-        if all(val in given for val in ["X", "y", "noise"]):
-            X, y, noise = given["X"], given["y"], given["noise"]
-            if not isinstance(noise, Covariance):
-                noise = pm.gp.cov.WhiteNoise(noise)
+
+        if "noise" in given:
+            warnings.warn(_noise_deprecation_warning, FutureWarning)
+            given["sigma"] = given["noise"]
+
+        if all(val in given for val in ["X", "y", "sigma"]):
+            X, y, sigma = given["X"], given["y"], given["sigma"]
+            noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
         else:
-            X, y, noise = self.X, self.y, self.noise
-        return X, y, noise, cov_total, mean_total
+            X, y, noise_func = self.X, self.y, self.sigma
+        return X, y, noise_func, cov_total, mean_total
 
     def _build_conditional(
-        self, Xnew, pred_noise, diag, X, y, noise, cov_total, mean_total, jitter
+        self, Xnew, pred_noise, diag, X, y, noise_func, cov_total, mean_total, jitter
     ):
         Kxx = cov_total(X)
         Kxs = self.cov_func(X, Xnew)
-        Knx = noise(X)
+        Knx = noise_func(X)
         rxx = y - mean_total(X)
         L = cholesky(stabilize(Kxx, jitter) + Knx)
         A = solve_lower(L, Kxs)
```
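
`_get_given_vals` now also accepts the old key: a `given` dict containing `"noise"` warns and is remapped to `"sigma"` before the usual lookup. A sketch of both spellings (names and values are hypothetical, reusing `X` and `y` from the first example):

```python
Xnew = np.linspace(1, 2, 20)[:, None]

with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.1))
    gp.marginal_likelihood("y_obs", X=X, y=y, sigma=0.1)

    # New key:
    f1 = gp.conditional("f1", Xnew, given={"X": X, "y": y, "sigma": 0.1, "gp": gp})

    # Old key is remapped and emits a FutureWarning:
    f2 = gp.conditional("f2", Xnew, given={"X": X, "y": y, "noise": 0.1, "gp": gp})
```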
```diff
@@ -495,13 +521,13 @@ def _build_conditional(
             Kss = self.cov_func(Xnew, diag=True)
             var = Kss - at.sum(at.square(A), 0)
             if pred_noise:
-                var += noise(Xnew, diag=True)
+                var += noise_func(Xnew, diag=True)
             return mu, var
         else:
             Kss = self.cov_func(Xnew)
             cov = Kss - at.dot(at.transpose(A), A)
             if pred_noise:
-                cov += noise(Xnew)
+                cov += noise_func(Xnew)
             return mu, cov if pred_noise else stabilize(cov, jitter)
 
     def conditional(
```
```diff
@@ -531,7 +557,7 @@ def conditional(
             Whether or not observation noise is included in the conditional.
             Default is `False`.
         given: dict
-            Can optionally take as key value pairs: `X`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `y`, `sigma`,
             and `gp`. See the section in the documentation on additive GP
             models in PyMC for more information.
         jitter: scalar
```
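
And `pred_noise=True` routes through the renamed `noise_func` terms shown in `_build_conditional` above, adding observation noise back into the prediction. Continuing the previous sketch:

```python
with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.1))
    gp.marginal_likelihood("y_obs", X=X, y=y, sigma=0.1)

    f_pred = gp.conditional("f_pred", Xnew)                   # latent function only
    y_pred = gp.conditional("y_pred", Xnew, pred_noise=True)  # plus observation noise
```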
```diff
@@ -720,7 +746,9 @@ def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
         quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
         return -1.0 * (constant + logdet + quadratic + trace)
 
-    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
+    def marginal_likelihood(
+        self, name, X, Xu, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, **kwargs
+    ):
         R"""
         Returns the approximate marginal likelihood distribution, given the input
         locations `X`, inducing point locations `Xu`, data `y`, and white noise
```
```diff
@@ -738,8 +766,10 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
         y: array-like
             Data that is the sum of the function with the GP prior and Gaussian
             noise. Must have shape `(n, )`.
-        noise: scalar, Variable
+        sigma: scalar, Variable
             Standard deviation of the Gaussian noise.
+        noise: scalar, Variable
+            Previous parameterization of `sigma`.
         jitter: scalar
             A small correction added to the diagonal of positive semi-definite
             covariance matrices to ensure numerical stability.
```
```diff
@@ -752,12 +782,11 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
         self.Xu = Xu
         self.y = y
 
-        if noise is None:
-            raise ValueError("noise argument must be specified")
-        else:
-            self.sigma = noise
+        self.sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)
 
-        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        approx_loglik = self._build_marginal_likelihood_loglik(
+            y=self.y, X=self.X, Xu=self.Xu, sigma=self.sigma, jitter=jitter
+        )
         pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)
 
     def _build_conditional(
```
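
`MarginalApprox.marginal_likelihood` gets the same treatment: the old `noise is None` check is replaced by the shared helper, and the log-likelihood is built from the stored attributes. A hedged usage sketch with inducing points (data, lengthscale, and the `approx` choice are illustrative):

```python
import numpy as np
import pymc as pm

X = np.linspace(0, 1, 200)[:, None]
Xu = np.linspace(0, 1, 20)[:, None]  # inducing point locations
y = np.random.randn(200)

with pm.Model():
    gp = pm.gp.MarginalApprox(cov_func=pm.gp.cov.ExpQuad(1, ls=0.1), approx="FITC")
    gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, sigma=0.1)
```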
```diff
@@ -828,7 +857,7 @@ def conditional(
             Whether or not observation noise is included in the conditional.
             Default is `False`.
         given: dict
-            Can optionally take as key value pairs: `X`, `Xu`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `Xu`, `y`, `sigma`,
             and `gp`. See the section in the documentation on additive GP
             models in PyMC for more information.
         jitter: scalar
```

pymc/tests/gp/test_gp.py

Lines changed: 80 additions & 8 deletions
```diff
@@ -24,6 +24,78 @@
 from pymc.math import cartesian
 
 
+class TestSigmaParams:
+    def setup_method(self):
+        """Common setup."""
+        self.x = np.linspace(-5, 5, 30)[:, None]
+        self.xu = np.linspace(-5, 5, 10)[:, None]
+        self.y = np.random.normal(0.25 * self.x, 0.1)
+
+        with pm.Model() as self.model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0)
+            mean_func = pm.gp.mean.Constant(c)
+            self.gp = self.gp_implementation(mean_func=mean_func, cov_func=cov_func)
+            self.sigma = pm.HalfNormal("sigma", sigma=100)
+
+
+class TestMarginalSigmaParams(TestSigmaParams):
+    R"""Tests for the deprecation warnings and raising ValueError."""
+
+    gp_implementation = pm.gp.Marginal
+
+    def test_catch_warnings(self):
+        """Warning from using the old noise parameter."""
+        with self.model:
+            with pytest.warns(FutureWarning):
+                self.gp.marginal_likelihood("lik_noise", X=self.x, y=self.y, noise=self.sigma)
+
+            with pytest.warns(FutureWarning):
+                self.gp.conditional(
+                    "cond_noise",
+                    Xnew=self.x,
+                    given={
+                        "noise": self.sigma,
+                    },
+                )
+
+    def test_raise_value_error(self):
+        """Either both or neither parameter is specified."""
+        with self.model:
+            with pytest.raises(ValueError):
+                self.gp.marginal_likelihood(
+                    "like_both", X=self.x, y=self.y, noise=self.sigma, sigma=self.sigma
+                )
+
+            with pytest.raises(ValueError):
+                self.gp.marginal_likelihood("like_neither", X=self.x, y=self.y)
+
+
+class TestMarginalApproxSigmaParams(TestSigmaParams):
+    R"""Tests for the deprecation warnings and raising ValueError."""
+
+    gp_implementation = pm.gp.MarginalApprox
+
+    def test_catch_warnings(self):
+        """Warning from using the old noise parameter."""
+        with self.model:
+            with pytest.warns(FutureWarning):
+                self.gp.marginal_likelihood(
+                    "lik_noise", X=self.x, Xu=self.xu, y=self.y, noise=self.sigma
+                )
+
+    def test_raise_value_error(self):
+        """Either both or neither parameter is specified."""
+        with self.model:
+            with pytest.raises(ValueError):
+                self.gp.marginal_likelihood(
+                    "like_both", X=self.x, Xu=self.xu, y=self.y, noise=self.sigma, sigma=self.sigma
+                )
+
+            with pytest.raises(ValueError):
+                self.gp.marginal_likelihood("like_neither", X=self.x, Xu=self.xu, y=self.y)
+
+
 class TestMarginalVsMarginalApprox:
     R"""
     Compare test fits of models Marginal and MarginalApprox.
```
```diff
@@ -113,20 +185,20 @@ def testAdditiveMarginal(self):
             gp3 = pm.gp.Marginal(mean_func=self.means[2], cov_func=self.covs[2])
 
             gpsum = gp1 + gp2 + gp3
-            fsum = gpsum.marginal_likelihood("f", self.X, self.y, noise=self.noise)
+            fsum = gpsum.marginal_likelihood("f", self.X, self.y, sigma=self.noise)
             model1_logp = model1.compile_logp()({})
 
         with pm.Model() as model2:
             gptot = pm.gp.Marginal(
                 mean_func=reduce(add, self.means), cov_func=reduce(add, self.covs)
             )
-            fsum = gptot.marginal_likelihood("f", self.X, self.y, noise=self.noise)
+            fsum = gptot.marginal_likelihood("f", self.X, self.y, sigma=self.noise)
             model2_logp = model2.compile_logp()({})
         npt.assert_allclose(model1_logp, model2_logp, atol=0, rtol=1e-2)
 
         with model1:
             fp1 = gpsum.conditional(
-                "fp1", self.Xnew, given={"X": self.X, "y": self.y, "noise": self.noise, "gp": gpsum}
+                "fp1", self.Xnew, given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum}
             )
         with model2:
             fp2 = gptot.conditional("fp2", self.Xnew)
```
```diff
@@ -152,14 +224,14 @@ def testAdditiveMarginalApprox(self, approx):
             )
 
             gpsum = gp1 + gp2 + gp3
-            fsum = gpsum.marginal_likelihood("f", self.X, Xu, self.y, noise=sigma)
+            fsum = gpsum.marginal_likelihood("f", self.X, Xu, self.y, sigma=sigma)
             model1_logp = model1.compile_logp()({})
 
         with pm.Model() as model2:
             gptot = pm.gp.MarginalApprox(
                 mean_func=reduce(add, self.means), cov_func=reduce(add, self.covs), approx=approx
             )
-            fsum = gptot.marginal_likelihood("f", self.X, Xu, self.y, noise=sigma)
+            fsum = gptot.marginal_likelihood("f", self.X, Xu, self.y, sigma=sigma)
             model2_logp = model2.compile_logp()({})
         npt.assert_allclose(model1_logp, model2_logp, atol=0, rtol=1e-2)
```
```diff
@@ -233,7 +305,7 @@ def testAdditiveTypeRaises2(self):
 
 class TestMarginalVsLatent:
     R"""
-    Compare the logp of models Marginal, noise=0 and Latent.
+    Compare the logp of models Marginal, sigma=0 and Latent.
     """
 
     def setup_method(self):
```
```diff
@@ -245,7 +317,7 @@ def setup_method(self):
             cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
             mean_func = pm.gp.mean.Constant(0.5)
             gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
-            f = gp.marginal_likelihood("f", X, y, noise=0.0)
+            f = gp.marginal_likelihood("f", X, y, sigma=0.0)
             p = gp.conditional("p", Xnew)
         self.logp = model.compile_logp()({"p": pnew})
         self.X = X
```
```diff
@@ -422,7 +494,7 @@ def setup_method(self):
             cov_func = pm.gp.cov.Kron(self.cov_funcs)
             self.mean = pm.gp.mean.Constant(0.5)
             gp = pm.gp.Marginal(mean_func=self.mean, cov_func=cov_func)
-            f = gp.marginal_likelihood("f", self.X, self.y, noise=self.sigma)
+            f = gp.marginal_likelihood("f", self.X, self.y, sigma=self.sigma)
             p = gp.conditional("p", self.Xnew)
             self.mu, self.cov = gp.predict(self.Xnew)
             self.logp = model.compile_logp()({"p": self.pnew})
```