Skip to content

Commit e636852

Browse files
ricardoV94michaelosthege
authored andcommitted
Refactor TestStepMethods and reduce number of draws
The number of draws was set too high to accommodate the worst / buggy samplers (see #5815)
1 parent 70a3e0b commit e636852

File tree

1 file changed

+58
-38
lines changed

1 file changed

+58
-38
lines changed

pymc/tests/test_step.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -78,46 +78,58 @@ def teardown_class(self):
7878
shutil.rmtree(self.temp_dir)
7979

8080
def check_stat(self, check, idata, name):
81-
if hasattr(idata, "warmup_posterior"):
82-
group = idata.warmup_posterior
83-
else:
84-
group = idata.posterior
81+
group = idata.posterior
8582
for (var, stat, value, bound) in check:
86-
s = stat(group[var].sel(chain=0, draw=slice(2000, None)), axis=0)
87-
close_to(s, value, bound)
83+
s = stat(group[var].sel(chain=0), axis=0)
84+
close_to(s, value, bound, name)
8885

89-
def test_step_continuous(self):
90-
start, model, (mu, C) = mv_simple()
91-
unc = np.diag(C) ** 0.5
92-
check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
93-
_, model_coarse, _ = mv_simple_coarse()
94-
with model:
95-
steps = (
96-
Slice(),
97-
HamiltonianMC(scaling=C, is_cov=True, blocked=False),
98-
NUTS(scaling=C, is_cov=True, blocked=False),
99-
Metropolis(S=C, proposal_dist=MultivariateNormalProposal, blocked=True),
100-
Slice(blocked=True),
101-
HamiltonianMC(scaling=C, is_cov=True),
102-
NUTS(scaling=C, is_cov=True),
103-
CompoundStep(
86+
@pytest.mark.parametrize(
87+
"step_fn, draws",
88+
[
89+
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True, blocked=False), 1000),
90+
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True), 1000),
91+
(lambda C, _: NUTS(scaling=C, is_cov=True, blocked=False), 1000),
92+
(lambda C, _: NUTS(scaling=C, is_cov=True), 1000),
93+
(
94+
lambda C, _: CompoundStep(
10495
[
10596
HamiltonianMC(scaling=C, is_cov=True),
10697
HamiltonianMC(scaling=C, is_cov=True, blocked=False),
10798
]
10899
),
109-
MLDA(
100+
1000,
101+
),
102+
# MLDA takes 1/2 of the total test time!
103+
(
104+
lambda C, model_coarse: MLDA(
110105
coarse_models=[model_coarse],
111106
base_S=C,
112107
base_proposal_dist=MultivariateNormalProposal,
113108
),
114-
)
115-
for step in steps:
109+
1000,
110+
),
111+
(lambda *_: Slice(), 2000),
112+
(lambda *_: Slice(blocked=True), 2000),
113+
(
114+
lambda C, _: Metropolis(
115+
S=C, proposal_dist=MultivariateNormalProposal, blocked=True
116+
),
117+
4000,
118+
),
119+
],
120+
ids=str,
121+
)
122+
def test_step_continuous(self, step_fn, draws):
123+
start, model, (mu, C) = mv_simple()
124+
unc = np.diag(C) ** 0.5
125+
check = (("x", np.mean, mu, unc / 10), ("x", np.std, unc, unc / 10))
126+
_, model_coarse, _ = mv_simple_coarse()
127+
with model:
128+
step = step_fn(C, model_coarse)
116129
idata = sample(
117-
0,
118-
tune=8000,
130+
tune=1000,
131+
draws=draws,
119132
chains=1,
120-
discard_tuned_samples=False,
121133
step=step,
122134
start=start,
123135
model=model,
@@ -126,30 +138,38 @@ def test_step_continuous(self):
126138
self.check_stat(check, idata, step.__class__.__name__)
127139

128140
def test_step_discrete(self):
129-
if aesara.config.floatX == "float32":
130-
return # Cannot use @skip because it only skips one iteration of the yield
131141
start, model, (mu, C) = mv_simple_discrete()
132142
unc = np.diag(C) ** 0.5
133143
check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
134144
with model:
135-
steps = (Metropolis(S=C, proposal_dist=MultivariateNormalProposal),)
136-
for step in steps:
145+
step = Metropolis(S=C, proposal_dist=MultivariateNormalProposal)
137146
idata = sample(
138-
20000, tune=0, step=step, start=start, model=model, random_seed=1, chains=1
147+
tune=1000,
148+
draws=2000,
149+
chains=1,
150+
step=step,
151+
start=start,
152+
model=model,
153+
random_seed=1,
139154
)
140155
self.check_stat(check, idata, step.__class__.__name__)
141156

142-
def test_step_categorical(self):
157+
@pytest.mark.parametrize("proposal", ["uniform", "proportional"])
158+
def test_step_categorical(self, proposal):
143159
start, model, (mu, C) = simple_categorical()
144160
unc = C**0.5
145161
check = (("x", np.mean, mu, unc / 10.0), ("x", np.std, unc, unc / 10.0))
146162
with model:
147-
steps = (
148-
CategoricalGibbsMetropolis([model.x], proposal="uniform"),
149-
CategoricalGibbsMetropolis([model.x], proposal="proportional"),
163+
step = CategoricalGibbsMetropolis([model.x], proposal=proposal)
164+
idata = sample(
165+
tune=1000,
166+
draws=2000,
167+
chains=1,
168+
step=step,
169+
start=start,
170+
model=model,
171+
random_seed=1,
150172
)
151-
for step in steps:
152-
idata = sample(8000, tune=0, step=step, start=start, model=model, random_seed=1)
153173
self.check_stat(check, idata, step.__class__.__name__)
154174

155175

0 commit comments

Comments
 (0)