@@ -698,21 +698,39 @@ def setup_method(self):
698
698
mean_func = pm .gp .mean .Constant (0.5 )
699
699
gp = pm .gp .Latent (mean_func , cov_func )
700
700
f = gp .prior ("f" , X )
701
- p = gp .conditional ("p" , Xnew )
702
701
chol = np .linalg .cholesky (cov_func (X ).eval ())
703
702
y_rotated = np .linalg .solve (chol , y - mean_func (X ).eval ())
704
- self .logp = model .logp ({
705
- "p" : pnew ,
706
- 'f_rotated_' : y_rotated
707
- })
703
+ logp_params = { 'f_rotated_' : y_rotated }
704
+ self .logp_prior = model .logp (logp_params )
705
+ with model :
706
+ p = gp .conditional ("p" , Xnew )
707
+ logp_params ['p' ] = pnew
708
+ self .logp_coditional = model .logp (logp_params )
708
709
self .X = X
709
710
self .Xnew = Xnew
710
711
self .y = y
711
712
self .pnew = pnew
712
713
self .gp = gp
713
714
714
715
@pytest .mark .parametrize ('approx' , ['FITC' , 'DTC' ])
715
- def testApproximations (self , approx ):
716
+ def testPriorApproximations (self , approx ):
717
+ with pm .Model () as model :
718
+ cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
719
+ mean_func = pm .gp .mean .Constant (0.5 )
720
+ gp = pm .gp .LatentSparse (mean_func , cov_func , approx = approx )
721
+ f = gp .prior ("f" , self .X , self .X )
722
+ chol = np .linalg .cholesky (cov_func (self .X ).eval ())
723
+ y_rotated = np .linalg .solve (chol , self .y - mean_func (self .X ).eval ())
724
+ model_params = {
725
+ "f_u_rotated_" : y_rotated ,
726
+ }
727
+ if approx == 'FITC' :
728
+ model_params ['f' ] = self .y # need to specify as well since f ~ Normal(f_, diag(Kff-Qff))
729
+ approx_logp = model .logp (model_params )
730
+ npt .assert_allclose (approx_logp , self .logp_prior , atol = 0 , rtol = 1e-2 )
731
+
732
+ @pytest .mark .parametrize ('approx' , ['FITC' , 'DTC' ])
733
+ def testConditionalApproximations (self , approx ):
716
734
with pm .Model () as model :
717
735
cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
718
736
mean_func = pm .gp .mean .Constant (0.5 )
@@ -728,7 +746,7 @@ def testApproximations(self, approx):
728
746
if approx == 'FITC' :
729
747
model_params ['f' ] = self .y # need to specify as well since f ~ Normal(f_, diag(Kff-Qff))
730
748
approx_logp = model .logp (model_params )
731
- npt .assert_allclose (approx_logp , self .logp , atol = 0 , rtol = 1e -2 )
749
+ npt .assert_allclose (approx_logp , self .logp_coditional , atol = 0 , rtol = 5e -2 )
732
750
733
751
734
752
class TestMarginalVsMarginalSparse (object ):
0 commit comments