60
60
)
61
61
from pymc .step_methods .mlda import extract_Q_estimate
62
62
from pymc .tests .checks import close_to
63
+ from pymc .tests .helpers import fast_unstable_sampling_mode
63
64
from pymc .tests .models import (
64
65
mv_simple ,
65
66
mv_simple_coarse ,
@@ -175,20 +176,21 @@ def test_step_categorical(self, proposal):
175
176
176
177
class TestMetropolisProposal :
177
178
def test_proposal_choice (self ):
178
- _ , model , _ = mv_simple ()
179
- with model :
180
- initial_point = model .initial_point ()
181
- initial_point_size = sum (initial_point [n .name ].size for n in model .value_vars )
179
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
180
+ _ , model , _ = mv_simple ()
181
+ with model :
182
+ initial_point = model .initial_point ()
183
+ initial_point_size = sum (initial_point [n .name ].size for n in model .value_vars )
182
184
183
- s = np .ones (initial_point_size )
184
- sampler = Metropolis (S = s )
185
- assert isinstance (sampler .proposal_dist , NormalProposal )
186
- s = np .diag (s )
187
- sampler = Metropolis (S = s )
188
- assert isinstance (sampler .proposal_dist , MultivariateNormalProposal )
189
- s [0 , 0 ] = - s [0 , 0 ]
190
- with pytest .raises (np .linalg .LinAlgError ):
185
+ s = np .ones (initial_point_size )
191
186
sampler = Metropolis (S = s )
187
+ assert isinstance (sampler .proposal_dist , NormalProposal )
188
+ s = np .diag (s )
189
+ sampler = Metropolis (S = s )
190
+ assert isinstance (sampler .proposal_dist , MultivariateNormalProposal )
191
+ s [0 , 0 ] = - s [0 , 0 ]
192
+ with pytest .raises (np .linalg .LinAlgError ):
193
+ sampler = Metropolis (S = s )
192
194
193
195
def test_mv_proposal (self ):
194
196
np .random .seed (42 )
@@ -202,59 +204,60 @@ def test_mv_proposal(self):
202
204
class TestCompoundStep :
203
205
samplers = (Metropolis , Slice , HamiltonianMC , NUTS , DEMetropolis )
204
206
205
- @pytest .mark .skipif (
206
- aesara .config .floatX == "float32" , reason = "Test fails on 32 bit due to linalg issues"
207
- )
208
207
def test_non_blocked (self ):
209
208
"""Test that samplers correctly create non-blocked compound steps."""
210
- _ , model = simple_2model_continuous ()
211
- with model :
212
- for sampler in self .samplers :
213
- assert isinstance (sampler (blocked = False ), CompoundStep )
209
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
210
+ _ , model = simple_2model_continuous ()
211
+ with model :
212
+ for sampler in self .samplers :
213
+ assert isinstance (sampler (blocked = False ), CompoundStep )
214
214
215
- @pytest .mark .skipif (
216
- aesara .config .floatX == "float32" , reason = "Test fails on 32 bit due to linalg issues"
217
- )
218
215
def test_blocked (self ):
219
- _ , model = simple_2model_continuous ()
220
- with model :
221
- for sampler in self .samplers :
222
- sampler_instance = sampler (blocked = True )
223
- assert not isinstance (sampler_instance , CompoundStep )
224
- assert isinstance (sampler_instance , sampler )
216
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
217
+ _ , model = simple_2model_continuous ()
218
+ with model :
219
+ for sampler in self .samplers :
220
+ sampler_instance = sampler (blocked = True )
221
+ assert not isinstance (sampler_instance , CompoundStep )
222
+ assert isinstance (sampler_instance , sampler )
225
223
226
224
227
225
class TestAssignStepMethods :
228
226
def test_bernoulli (self ):
229
227
"""Test bernoulli distribution is assigned binary gibbs metropolis method"""
230
228
with Model () as model :
231
229
Bernoulli ("x" , 0.5 )
232
- steps = assign_step_methods (model , [])
230
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
231
+ steps = assign_step_methods (model , [])
233
232
assert isinstance (steps , BinaryGibbsMetropolis )
234
233
235
234
def test_normal (self ):
236
235
"""Test normal distribution is assigned NUTS method"""
237
236
with Model () as model :
238
237
Normal ("x" , 0 , 1 )
239
- steps = assign_step_methods (model , [])
238
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
239
+ steps = assign_step_methods (model , [])
240
240
assert isinstance (steps , NUTS )
241
241
242
242
def test_categorical (self ):
243
243
"""Test categorical distribution is assigned categorical gibbs metropolis method"""
244
244
with Model () as model :
245
245
Categorical ("x" , np .array ([0.25 , 0.75 ]))
246
- steps = assign_step_methods (model , [])
246
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
247
+ steps = assign_step_methods (model , [])
247
248
assert isinstance (steps , BinaryGibbsMetropolis )
248
249
with Model () as model :
249
250
Categorical ("y" , np .array ([0.25 , 0.70 , 0.05 ]))
250
- steps = assign_step_methods (model , [])
251
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
252
+ steps = assign_step_methods (model , [])
251
253
assert isinstance (steps , CategoricalGibbsMetropolis )
252
254
253
255
def test_binomial (self ):
254
256
"""Test binomial distribution is assigned metropolis method."""
255
257
with Model () as model :
256
258
Binomial ("x" , 10 , 0.5 )
257
- steps = assign_step_methods (model , [])
259
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
260
+ steps = assign_step_methods (model , [])
258
261
assert isinstance (steps , Metropolis )
259
262
260
263
def test_normal_nograd_op (self ):
@@ -274,7 +277,8 @@ def kill_grad(x):
274
277
data = np .random .normal (size = (100 ,))
275
278
Normal ("y" , mu = kill_grad (x ), sigma = 1 , observed = data .astype (aesara .config .floatX ))
276
279
277
- steps = assign_step_methods (model , [])
280
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
281
+ steps = assign_step_methods (model , [])
278
282
assert isinstance (steps , Slice )
279
283
280
284
def test_modify_step_methods (self ):
@@ -286,15 +290,17 @@ def test_modify_step_methods(self):
286
290
287
291
with Model () as model :
288
292
Normal ("x" , 0 , 1 )
289
- steps = assign_step_methods (model , [])
293
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
294
+ steps = assign_step_methods (model , [])
290
295
assert not isinstance (steps , NUTS )
291
296
292
297
# add back nuts
293
298
pm .STEP_METHODS = step_methods + [NUTS ]
294
299
295
300
with Model () as model :
296
301
Normal ("x" , 0 , 1 )
297
- steps = assign_step_methods (model , [])
302
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
303
+ steps = assign_step_methods (model , [])
298
304
assert isinstance (steps , NUTS )
299
305
300
306
@@ -1326,7 +1332,8 @@ def test_continuous_steps(self, step, step_kwargs):
1326
1332
c1 = HalfNormal ("c1" )
1327
1333
c2 = HalfNormal ("c2" )
1328
1334
1329
- assert [m .rvs_to_values [c1 ]] == step ([c1 ], ** step_kwargs ).vars
1335
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
1336
+ assert [m .rvs_to_values [c1 ]] == step ([c1 ], ** step_kwargs ).vars
1330
1337
assert {m .rvs_to_values [c1 ], m .rvs_to_values [c2 ]} == set (
1331
1338
step ([c1 , c2 ], ** step_kwargs ).vars
1332
1339
)
@@ -1343,7 +1350,8 @@ def test_discrete_steps(self, step, step_kwargs):
1343
1350
d1 = Bernoulli ("d1" , p = 0.5 )
1344
1351
d2 = Bernoulli ("d2" , p = 0.5 )
1345
1352
1346
- assert [m .rvs_to_values [d1 ]] == step ([d1 ], ** step_kwargs ).vars
1353
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
1354
+ assert [m .rvs_to_values [d1 ]] == step ([d1 ], ** step_kwargs ).vars
1347
1355
assert {m .rvs_to_values [d1 ], m .rvs_to_values [d2 ]} == set (
1348
1356
step ([d1 , d2 ], ** step_kwargs ).vars
1349
1357
)
@@ -1353,7 +1361,8 @@ def test_compound_step(self):
1353
1361
c1 = HalfNormal ("c1" )
1354
1362
c2 = HalfNormal ("c2" )
1355
1363
1356
- step1 = NUTS ([c1 ])
1357
- step2 = NUTS ([c2 ])
1358
- step = CompoundStep ([step1 , step2 ])
1364
+ with aesara .config .change_flags (mode = fast_unstable_sampling_mode ):
1365
+ step1 = NUTS ([c1 ])
1366
+ step2 = NUTS ([c2 ])
1367
+ step = CompoundStep ([step1 , step2 ])
1359
1368
assert {m .rvs_to_values [c1 ], m .rvs_to_values [c2 ]} == set (step .vars )
0 commit comments