 import aesara.tensor as at
 import numpy as np
 import pytest
+import scipy.stats as st
 
 from arviz.data.inference_data import InferenceData
 
 import pymc3 as pm
 
+from pymc3.aesaraf import floatX
 from pymc3.backends.base import MultiTrace
+from pymc3.smc.smc import SMC
 from pymc3.tests.helpers import SeededTest
 
 
@@ -64,10 +67,6 @@ def two_gaussians(x):
             x = pm.Normal("x", 0, 1)
             y = pm.Normal("y", x, 1, observed=0)
 
-        with pm.Model() as self.slow_model:
-            x = pm.Normal("x", 0, 1)
-            y = pm.Normal("y", x, 1, observed=100)
-
     def test_sample(self):
         with self.SMC_test:
             mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
@@ -76,12 +75,43 @@ def test_sample(self):
         mu1d = np.abs(x).mean(axis=0)
         np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03)
 
-    def test_discrete_continuous(self):
-        with pm.Model() as model:
-            a = pm.Poisson("a", 5)
-            b = pm.HalfNormal("b", 10)
-            y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
-            trace = pm.sample_smc(draws=10)
+    def test_discrete_rounding_proposal(self):
+        """
+        Test that discrete variable values are automatically rounded
+        in SMC logp functions
+        """
+
+        with pm.Model() as m:
+            z = pm.Bernoulli("z", p=0.7)
+            like = pm.Potential("like", z * 1.0)
+
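+        # Build the SMC kernel by hand so that prior_logp_func can be
+        # evaluated directly on raw (continuous) proposal values.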
+        smc = SMC(model=m)
+        smc.initialize_population()
+        smc.setup_kernel()
+        smc.initialize_logp()
+
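+        # Proposed values are rounded to the nearest integer before the
+        # Bernoulli(0.7) logp is evaluated: -0.51 -> -1 and 1.51 -> 2 fall
+        # outside the support (-inf), while -0.49/0.49 -> 0 give log(0.3)
+        # and 0.51 -> 1 gives log(0.7).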
+        assert smc.prior_logp_func(floatX(np.array([-0.51]))) == -np.inf
+        assert np.isclose(smc.prior_logp_func(floatX(np.array([-0.49]))), np.log(0.3))
+        assert np.isclose(smc.prior_logp_func(floatX(np.array([0.49]))), np.log(0.3))
+        assert np.isclose(smc.prior_logp_func(floatX(np.array([0.51]))), np.log(0.7))
+        assert smc.prior_logp_func(floatX(np.array([1.51]))) == -np.inf
+
+    def test_unobserved_discrete(self):
+        n = 10
+        rng = self.get_random_state()
+
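+        # Simulate data from two well-separated components (means -1 and 1,
+        # sd 0.25); the first half of the true indicators is 0, the second half 1.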
+        z_true = np.zeros(n, dtype=int)
+        z_true[int(n / 2) :] = 1
+        y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng)
+
+        with pm.Model() as m:
+            z = pm.Bernoulli("z", p=0.5, size=n)
+            mu = pm.math.switch(z, 1.0, -1.0)
+            like = pm.Normal("like", mu=mu, sigma=0.25, observed=y)
+
+        trace = pm.sample_smc(chains=1, return_inferencedata=False)
+
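+        # With this separation the posterior for each z[i] should concentrate
+        # on the true assignment, so the posterior median recovers z_true.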
+        assert np.all(np.median(trace["z"], axis=0) == z_true)
 
     def test_ml(self):
         data = np.repeat([1, 0], [50, 50])
@@ -109,14 +139,6 @@ def test_start(self):
         }
         trace = pm.sample_smc(500, chains=1, start=start)
 
-    def test_slowdown_warning(self):
-        with aesara.config.change_flags(floatX="float32"):
-            with pytest.warns(UserWarning, match="SMC sampling may run slower due to"):
-                with pm.Model() as model:
-                    a = pm.Poisson("a", 5)
-                    y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
-                    trace = pm.sample_smc(draws=100, chains=2, cores=1)
-
     @pytest.mark.parametrize("chains", (1, 2))
     def test_return_datatype(self, chains):
         draws = 10