Skip to content

Commit a918d02

Browse files
committed
Refactor test_sample_init
1 parent a8af470 commit a918d02

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

tests/sampling/test_mcmc.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -148,32 +148,35 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
148148
assert np.all(idata12["x"] != idata22["x"])
149149
assert np.all(idata13["x"] != idata23["x"])
150150

151-
def test_sample_init(self):
151+
@pytest.mark.parametrize(
152+
"init",
153+
(
154+
"advi",
155+
"advi_map",
156+
"map",
157+
"adapt_diag",
158+
"jitter+adapt_diag",
159+
"jitter+adapt_diag_grad",
160+
"adapt_full",
161+
"jitter+adapt_full",
162+
),
163+
)
164+
def test_sample_init(self, init):
152165
with self.model:
153-
for init in (
154-
"advi",
155-
"advi_map",
156-
"map",
157-
"adapt_diag",
158-
"jitter+adapt_diag",
159-
"jitter+adapt_diag_grad",
160-
"adapt_full",
161-
"jitter+adapt_full",
162-
):
163-
kwargs = {
164-
"init": init,
165-
"tune": 120,
166-
"n_init": 1000,
167-
"draws": 50,
168-
"random_seed": 20160911,
169-
}
170-
with warnings.catch_warnings(record=True) as rec:
171-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
172-
if init.endswith("adapt_full"):
173-
with pytest.warns(UserWarning, match="experimental feature"):
174-
pm.sample(**kwargs)
175-
else:
176-
pm.sample(**kwargs)
166+
kwargs = {
167+
"init": init,
168+
"tune": 120,
169+
"n_init": 1000,
170+
"draws": 50,
171+
"random_seed": 20160911,
172+
}
173+
with warnings.catch_warnings(record=True) as rec:
174+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
175+
if init.endswith("adapt_full"):
176+
with pytest.warns(UserWarning, match="experimental feature"):
177+
pm.sample(**kwargs, cores=1)
178+
else:
179+
pm.sample(**kwargs, cores=1)
177180

178181
def test_sample_args(self):
179182
with self.model:

0 commit comments

Comments
 (0)