Skip to content

Commit 336d9a6

Browse files
committed
Avoid unnecessary slow compilation of Aesara functions in test_step
1 parent 00f15a7 commit 336d9a6

File tree

1 file changed

+50
-41
lines changed

1 file changed

+50
-41
lines changed

pymc/tests/test_step.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
)
6161
from pymc.step_methods.mlda import extract_Q_estimate
6262
from pymc.tests.checks import close_to
63+
from pymc.tests.helpers import fast_unstable_sampling_mode
6364
from pymc.tests.models import (
6465
mv_simple,
6566
mv_simple_coarse,
@@ -175,20 +176,21 @@ def test_step_categorical(self, proposal):
175176

176177
class TestMetropolisProposal:
177178
def test_proposal_choice(self):
178-
_, model, _ = mv_simple()
179-
with model:
180-
initial_point = model.initial_point()
181-
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
179+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
180+
_, model, _ = mv_simple()
181+
with model:
182+
initial_point = model.initial_point()
183+
initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
182184

183-
s = np.ones(initial_point_size)
184-
sampler = Metropolis(S=s)
185-
assert isinstance(sampler.proposal_dist, NormalProposal)
186-
s = np.diag(s)
187-
sampler = Metropolis(S=s)
188-
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
189-
s[0, 0] = -s[0, 0]
190-
with pytest.raises(np.linalg.LinAlgError):
185+
s = np.ones(initial_point_size)
191186
sampler = Metropolis(S=s)
187+
assert isinstance(sampler.proposal_dist, NormalProposal)
188+
s = np.diag(s)
189+
sampler = Metropolis(S=s)
190+
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
191+
s[0, 0] = -s[0, 0]
192+
with pytest.raises(np.linalg.LinAlgError):
193+
sampler = Metropolis(S=s)
192194

193195
def test_mv_proposal(self):
194196
np.random.seed(42)
@@ -202,59 +204,60 @@ def test_mv_proposal(self):
202204
class TestCompoundStep:
203205
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)
204206

205-
@pytest.mark.skipif(
206-
aesara.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues"
207-
)
208207
def test_non_blocked(self):
209208
"""Test that samplers correctly create non-blocked compound steps."""
210-
_, model = simple_2model_continuous()
211-
with model:
212-
for sampler in self.samplers:
213-
assert isinstance(sampler(blocked=False), CompoundStep)
209+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
210+
_, model = simple_2model_continuous()
211+
with model:
212+
for sampler in self.samplers:
213+
assert isinstance(sampler(blocked=False), CompoundStep)
214214

215-
@pytest.mark.skipif(
216-
aesara.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues"
217-
)
218215
def test_blocked(self):
219-
_, model = simple_2model_continuous()
220-
with model:
221-
for sampler in self.samplers:
222-
sampler_instance = sampler(blocked=True)
223-
assert not isinstance(sampler_instance, CompoundStep)
224-
assert isinstance(sampler_instance, sampler)
216+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
217+
_, model = simple_2model_continuous()
218+
with model:
219+
for sampler in self.samplers:
220+
sampler_instance = sampler(blocked=True)
221+
assert not isinstance(sampler_instance, CompoundStep)
222+
assert isinstance(sampler_instance, sampler)
225223

226224

227225
class TestAssignStepMethods:
228226
def test_bernoulli(self):
229227
"""Test bernoulli distribution is assigned binary gibbs metropolis method"""
230228
with Model() as model:
231229
Bernoulli("x", 0.5)
232-
steps = assign_step_methods(model, [])
230+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
231+
steps = assign_step_methods(model, [])
233232
assert isinstance(steps, BinaryGibbsMetropolis)
234233

235234
def test_normal(self):
236235
"""Test normal distribution is assigned NUTS method"""
237236
with Model() as model:
238237
Normal("x", 0, 1)
239-
steps = assign_step_methods(model, [])
238+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
239+
steps = assign_step_methods(model, [])
240240
assert isinstance(steps, NUTS)
241241

242242
def test_categorical(self):
243243
"""Test categorical distribution is assigned categorical gibbs metropolis method"""
244244
with Model() as model:
245245
Categorical("x", np.array([0.25, 0.75]))
246-
steps = assign_step_methods(model, [])
246+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
247+
steps = assign_step_methods(model, [])
247248
assert isinstance(steps, BinaryGibbsMetropolis)
248249
with Model() as model:
249250
Categorical("y", np.array([0.25, 0.70, 0.05]))
250-
steps = assign_step_methods(model, [])
251+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
252+
steps = assign_step_methods(model, [])
251253
assert isinstance(steps, CategoricalGibbsMetropolis)
252254

253255
def test_binomial(self):
254256
"""Test binomial distribution is assigned metropolis method."""
255257
with Model() as model:
256258
Binomial("x", 10, 0.5)
257-
steps = assign_step_methods(model, [])
259+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
260+
steps = assign_step_methods(model, [])
258261
assert isinstance(steps, Metropolis)
259262

260263
def test_normal_nograd_op(self):
@@ -274,7 +277,8 @@ def kill_grad(x):
274277
data = np.random.normal(size=(100,))
275278
Normal("y", mu=kill_grad(x), sigma=1, observed=data.astype(aesara.config.floatX))
276279

277-
steps = assign_step_methods(model, [])
280+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
281+
steps = assign_step_methods(model, [])
278282
assert isinstance(steps, Slice)
279283

280284
def test_modify_step_methods(self):
@@ -286,15 +290,17 @@ def test_modify_step_methods(self):
286290

287291
with Model() as model:
288292
Normal("x", 0, 1)
289-
steps = assign_step_methods(model, [])
293+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
294+
steps = assign_step_methods(model, [])
290295
assert not isinstance(steps, NUTS)
291296

292297
# add back nuts
293298
pm.STEP_METHODS = step_methods + [NUTS]
294299

295300
with Model() as model:
296301
Normal("x", 0, 1)
297-
steps = assign_step_methods(model, [])
302+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
303+
steps = assign_step_methods(model, [])
298304
assert isinstance(steps, NUTS)
299305

300306

@@ -1326,7 +1332,8 @@ def test_continuous_steps(self, step, step_kwargs):
13261332
c1 = HalfNormal("c1")
13271333
c2 = HalfNormal("c2")
13281334

1329-
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
1335+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
1336+
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
13301337
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
13311338
step([c1, c2], **step_kwargs).vars
13321339
)
@@ -1343,7 +1350,8 @@ def test_discrete_steps(self, step, step_kwargs):
13431350
d1 = Bernoulli("d1", p=0.5)
13441351
d2 = Bernoulli("d2", p=0.5)
13451352

1346-
assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
1353+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
1354+
assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
13471355
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
13481356
step([d1, d2], **step_kwargs).vars
13491357
)
@@ -1353,7 +1361,8 @@ def test_compound_step(self):
13531361
c1 = HalfNormal("c1")
13541362
c2 = HalfNormal("c2")
13551363

1356-
step1 = NUTS([c1])
1357-
step2 = NUTS([c2])
1358-
step = CompoundStep([step1, step2])
1364+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
1365+
step1 = NUTS([c1])
1366+
step2 = NUTS([c2])
1367+
step = CompoundStep([step1, step2])
13591368
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)

0 commit comments

Comments
 (0)