__all__ = ["Latent", "Marginal", "TP", "MarginalApprox", "LatentKron", "MarginalKron"]


+_noise_deprecation_warning = (
+    "The 'noise' parameter has been changed to 'sigma' "
+    "in order to standardize the GP API and will be "
+    "deprecated in future releases."
+)
+
+
+def _handle_sigma_noise_parameters(sigma, noise):
+    """Helper function for transition of 'noise' parameter to be named 'sigma'."""
+
+    if (sigma is None and noise is None) or (sigma is not None and noise is not None):
+        raise ValueError("'sigma' argument must be specified.")
+
+    if sigma is None:
+        warnings.warn(_noise_deprecation_warning, FutureWarning)
+        return noise
+
+    return sigma
+
+
class Base:
    R"""
    Base class.
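A minimal sketch of the transition behavior the new helper implements, assuming it is in scope as defined in the hunk above (the values below are illustrative, not part of the diff):

import warnings

# Old keyword still works: the value is passed through and a FutureWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert _handle_sigma_noise_parameters(sigma=None, noise=1.0) == 1.0
    assert any(issubclass(w.category, FutureWarning) for w in caught)

# New keyword passes through silently.
assert _handle_sigma_noise_parameters(sigma=2.0, noise=None) == 2.0

# Supplying neither (or both) raises the ValueError shown above.
try:
    _handle_sigma_noise_parameters(sigma=None, noise=None)
except ValueError:
    pass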
@@ -218,7 +238,7 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
        Xnew: array-like
            Function input values.
        given: dict
-            Can optionally take as key value pairs: `X`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `y`,
            and `gp`. See the section in the documentation on additive GP
            models in PyMC for more information.
        jitter: scalar
@@ -359,7 +379,7 @@ def conditional(self, name, Xnew, jitter=JITTER_DEFAULT, **kwargs):
        return pm.MvStudentT(name, nu=nu2, mu=mu, cov=cov, **kwargs)


-@conditioned_vars(["X", "y", "noise"])
+@conditioned_vars(["X", "y", "sigma"])
class Marginal(Base):
    R"""
    Marginal Gaussian process.
@@ -393,7 +413,7 @@ class Marginal(Base):

        # Place a GP prior over the function f.
        sigma = pm.HalfCauchy("sigma", beta=3)
-        y_ = gp.marginal_likelihood("y", X=X, y=y, noise=sigma)
+        y_ = gp.marginal_likelihood("y", X=X, y=y, sigma=sigma)

    ...

@@ -405,15 +425,15 @@ class Marginal(Base):
        fcond = gp.conditional("fcond", Xnew=Xnew)
    """

-    def _build_marginal_likelihood(self, X, noise, jitter):
+    def _build_marginal_likelihood(self, X, noise_func, jitter):
        mu = self.mean_func(X)
        Kxx = self.cov_func(X)
-        Knx = noise(X)
+        Knx = noise_func(X)
        cov = Kxx + Knx
        return mu, stabilize(cov, jitter)

    def marginal_likelihood(
-        self, name, X, y, noise, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
+        self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
    ):
        R"""
        Returns the marginal likelihood distribution, given the input
@@ -435,23 +455,25 @@ def marginal_likelihood(
        y: array-like
            Data that is the sum of the function with the GP prior and Gaussian
            noise. Must have shape `(n, )`.
-        noise: scalar, Variable, or Covariance
+        sigma: scalar, Variable, or Covariance
            Standard deviation of the Gaussian noise. Can also be a Covariance for
            non-white noise.
+        noise: scalar, Variable, or Covariance
+            Previous parameterization of `sigma`.
        jitter: scalar
            A small correction added to the diagonal of positive semi-definite
            covariance matrices to ensure numerical stability.
        **kwargs
            Extra keyword arguments that are passed to `MvNormal` distribution
            constructor.
        """
+        sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)

-        if not isinstance(noise, Covariance):
-            noise = pm.gp.cov.WhiteNoise(noise)
-        mu, cov = self._build_marginal_likelihood(X, noise, jitter)
+        noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
+        mu, cov = self._build_marginal_likelihood(X=X, noise_func=noise_func, jitter=jitter)
        self.X = X
        self.y = y
-        self.noise = noise
+        self.sigma = noise_func
        if is_observed:
            return pm.MvNormal(name, mu=mu, cov=cov, observed=y, **kwargs)
        else:
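To make the transition concrete, a usage sketch for the updated signature (data, priors, and variable names below are illustrative, not part of the diff):

import numpy as np
import pymc as pm

X = np.linspace(0, 1, 20)[:, None]
y = np.random.normal(size=20)

with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.2))
    sigma = pm.HalfNormal("sigma", 1.0)

    # New spelling of the noise argument:
    y_obs = gp.marginal_likelihood("y_obs", X=X, y=y, sigma=sigma)

    # The old spelling still works for now but emits a FutureWarning:
    # y_obs = gp.marginal_likelihood("y_obs", X=X, y=y, noise=sigma)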
@@ -472,20 +494,24 @@ def _get_given_vals(self, given):
        else:
            cov_total = self.cov_func
            mean_total = self.mean_func
-        if all(val in given for val in ["X", "y", "noise"]):
-            X, y, noise = given["X"], given["y"], given["noise"]
-            if not isinstance(noise, Covariance):
-                noise = pm.gp.cov.WhiteNoise(noise)
+
+        if "noise" in given:
+            warnings.warn(_noise_deprecation_warning, FutureWarning)
+            given["sigma"] = given["noise"]
+
+        if all(val in given for val in ["X", "y", "sigma"]):
+            X, y, sigma = given["X"], given["y"], given["sigma"]
+            noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
        else:
-            X, y, noise = self.X, self.y, self.noise
-        return X, y, noise, cov_total, mean_total
+            X, y, noise_func = self.X, self.y, self.sigma
+        return X, y, noise_func, cov_total, mean_total

    def _build_conditional(
-        self, Xnew, pred_noise, diag, X, y, noise, cov_total, mean_total, jitter
+        self, Xnew, pred_noise, diag, X, y, noise_func, cov_total, mean_total, jitter
    ):
        Kxx = cov_total(X)
        Kxs = self.cov_func(X, Xnew)
-        Knx = noise(X)
+        Knx = noise_func(X)
        rxx = y - mean_total(X)
        L = cholesky(stabilize(Kxx, jitter) + Knx)
        A = solve_lower(L, Kxs)
@@ -495,13 +521,13 @@ def _build_conditional(
            Kss = self.cov_func(Xnew, diag=True)
            var = Kss - at.sum(at.square(A), 0)
            if pred_noise:
-                var += noise(Xnew, diag=True)
+                var += noise_func(Xnew, diag=True)
            return mu, var
        else:
            Kss = self.cov_func(Xnew)
            cov = Kss - at.dot(at.transpose(A), A)
            if pred_noise:
-                cov += noise(Xnew)
+                cov += noise_func(Xnew)
            return mu, cov if pred_noise else stabilize(cov, jitter)

    def conditional(
@@ -531,7 +557,7 @@ def conditional(
            Whether or not observation noise is included in the conditional.
            Default is `False`.
        given: dict
-            Can optionally take as key value pairs: `X`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `y`, `sigma`,
            and `gp`. See the section in the documentation on additive GP
            models in PyMC for more information.
        jitter: scalar
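Continuing the usage sketch above (inside the same `with pm.Model()` block; names remain illustrative), the `given` dict now takes a `sigma` key, while a dict that still uses `noise` is remapped by `_get_given_vals` with the same `FutureWarning`:

    Xnew = np.linspace(0, 1.5, 30)[:, None]

    # New key name; passing "noise" instead would still work for now but warns.
    fcond = gp.conditional(
        "fcond", Xnew=Xnew, given={"X": X, "y": y, "sigma": sigma, "gp": gp}
    )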
@@ -720,7 +746,9 @@ def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
        quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
        return -1.0 * (constant + logdet + quadratic + trace)

-    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
+    def marginal_likelihood(
+        self, name, X, Xu, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, **kwargs
+    ):
        R"""
        Returns the approximate marginal likelihood distribution, given the input
        locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -738,8 +766,10 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
        y: array-like
            Data that is the sum of the function with the GP prior and Gaussian
            noise. Must have shape `(n, )`.
-        noise: scalar, Variable
+        sigma: scalar, Variable
            Standard deviation of the Gaussian noise.
+        noise: scalar, Variable
+            Previous parameterization of `sigma`.
        jitter: scalar
            A small correction added to the diagonal of positive semi-definite
            covariance matrices to ensure numerical stability.
@@ -752,12 +782,11 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
        self.Xu = Xu
        self.y = y

-        if noise is None:
-            raise ValueError("noise argument must be specified")
-        else:
-            self.sigma = noise
+        self.sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)

-        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        approx_loglik = self._build_marginal_likelihood_loglik(
+            y=self.y, X=self.X, Xu=self.Xu, sigma=self.sigma, jitter=jitter
+        )
        pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)

    def _build_conditional(
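A corresponding sketch for `MarginalApprox` under the new signature (the data, inducing points, and the `approx` choice are illustrative assumptions, not part of the diff):

import numpy as np
import pymc as pm

X = np.linspace(0, 1, 50)[:, None]
Xu = np.linspace(0, 1, 10)[:, None]  # inducing point locations
y = np.random.normal(size=50)

with pm.Model():
    gp = pm.gp.MarginalApprox(cov_func=pm.gp.cov.ExpQuad(1, ls=0.2), approx="VFE")
    sigma = pm.HalfNormal("sigma", 1.0)

    # `sigma=` is the new keyword; `noise=` still works with a FutureWarning,
    # and omitting both raises the ValueError from _handle_sigma_noise_parameters.
    gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, sigma=sigma)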
@@ -828,7 +857,7 @@ def conditional(
            Whether or not observation noise is included in the conditional.
            Default is `False`.
        given: dict
-            Can optionally take as key value pairs: `X`, `Xu`, `y`, `noise`,
+            Can optionally take as key value pairs: `X`, `Xu`, `y`, `sigma`,
            and `gp`. See the section in the documentation on additive GP
            models in PyMC for more information.
        jitter: scalar