25
25
import pymc3 as pm
26
26
import pymc3 .parallel_sampling as ps
27
27
28
+ from pymc3 .aesaraf import floatX
29
+
28
30
29
31
def test_context ():
30
32
with pm .Model ():
@@ -83,15 +85,13 @@ def test_remote_pipe_closed():
83
85
pm .sample (step = step , mp_ctx = "spawn" , tune = 2 , draws = 2 , cores = 2 , chains = 2 )
84
86
85
87
86
- @pytest .mark .xfail (
87
- reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
88
- )
88
+ @pytest .mark .xfail (reason = "Unclear" )
89
89
def test_abort ():
90
90
with pm .Model () as model :
91
91
a = pm .Normal ("a" , shape = 1 )
92
- pm .HalfNormal ("b" )
93
- step1 = pm .NUTS ([a ])
94
- step2 = pm .Metropolis ([model [ "b_log__" ]])
92
+ b = pm .HalfNormal ("b" )
93
+ step1 = pm .NUTS ([model . rvs_to_values [ a ] ])
94
+ step2 = pm .Metropolis ([model . rvs_to_values [ b ]])
95
95
96
96
step = pm .CompoundStep ([step1 , step2 ])
97
97
@@ -104,7 +104,7 @@ def test_abort():
104
104
chain = 3 ,
105
105
seed = 1 ,
106
106
mp_ctx = ctx ,
107
- start = {"a" : np .array ([1.0 ]), "b_log__" : np .array (2.0 )},
107
+ start = {"a" : floatX ( np .array ([1.0 ])) , "b_log__" : floatX ( np .array (2.0 ) )},
108
108
step_method_pickled = None ,
109
109
)
110
110
proc .start ()
@@ -118,15 +118,12 @@ def test_abort():
118
118
proc .join ()
119
119
120
120
121
- @pytest .mark .xfail (
122
- reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
123
- )
124
121
def test_explicit_sample ():
125
122
with pm .Model () as model :
126
123
a = pm .Normal ("a" , shape = 1 )
127
- pm .HalfNormal ("b" )
128
- step1 = pm .NUTS ([a ])
129
- step2 = pm .Metropolis ([model [ "b_log__" ]])
124
+ b = pm .HalfNormal ("b" )
125
+ step1 = pm .NUTS ([model . rvs_to_values [ a ] ])
126
+ step2 = pm .Metropolis ([model . rvs_to_values [ b ]])
130
127
131
128
step = pm .CompoundStep ([step1 , step2 ])
132
129
@@ -138,7 +135,7 @@ def test_explicit_sample():
138
135
chain = 3 ,
139
136
seed = 1 ,
140
137
mp_ctx = ctx ,
141
- start = {"a" : np .array ([1.0 ]), "b_log__" : np .array (2.0 )},
138
+ start = {"a" : floatX ( np .array ([1.0 ])) , "b_log__" : floatX ( np .array (2.0 ) )},
142
139
step_method_pickled = None ,
143
140
)
144
141
proc .start ()
@@ -153,19 +150,16 @@ def test_explicit_sample():
153
150
proc .join ()
154
151
155
152
156
- @pytest .mark .xfail (
157
- reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
158
- )
159
153
def test_iterator ():
160
154
with pm .Model () as model :
161
155
a = pm .Normal ("a" , shape = 1 )
162
- pm .HalfNormal ("b" )
163
- step1 = pm .NUTS ([a ])
164
- step2 = pm .Metropolis ([model [ "b_log__" ]])
156
+ b = pm .HalfNormal ("b" )
157
+ step1 = pm .NUTS ([model . rvs_to_values [ a ] ])
158
+ step2 = pm .Metropolis ([model . rvs_to_values [ b ]])
165
159
166
160
step = pm .CompoundStep ([step1 , step2 ])
167
161
168
- start = {"a" : np .array ([1.0 ]), "b_log__" : np .array (2.0 )}
162
+ start = {"a" : floatX ( np .array ([1.0 ])) , "b_log__" : floatX ( np .array (2.0 ) )}
169
163
sampler = ps .ParallelSampler (10 , 10 , 3 , 2 , [2 , 3 , 4 ], [start ] * 3 , step , 0 , False )
170
164
with sampler :
171
165
for draw in sampler :
0 commit comments