Skip to content

Commit 95a0ef1

Browse files
Small changes to TestImputationMissingData tests
Co-authored-by: Michael Osthege <[email protected]>
1 parent 6790049 commit 95a0ef1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc/tests/test_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,8 +1154,8 @@ def test_missing_basic(self, missing_data):
11541154
assert not np.isnan(model.compile_logp()(test_point))
11551155

11561156
with model:
1157-
prior_trace = pm.sample_prior_predictive(return_inferencedata=False)
1158-
assert {"x", "y"} <= set(prior_trace.keys())
1157+
ipr = pm.sample_prior_predictive()
1158+
assert {"x", "y"} <= set(ipr.prior.keys())
11591159

11601160
def test_missing_with_predictors(self):
11611161
predictors = np.array([0.5, 1, 0.5, 2, 0.3])
@@ -1171,8 +1171,8 @@ def test_missing_with_predictors(self):
11711171
assert not np.isnan(model.compile_logp()(test_point))
11721172

11731173
with model:
1174-
prior_trace = pm.sample_prior_predictive(return_inferencedata=False)
1175-
assert {"x", "y"} <= set(prior_trace.keys())
1174+
ipr = pm.sample_prior_predictive()
1175+
assert {"x", "y"} <= set(ipr.prior.keys())
11761176

11771177
def test_missing_dual_observations(self):
11781178
with pm.Model() as model:
@@ -1191,7 +1191,7 @@ def test_missing_dual_observations(self):
11911191
# TODO: Assert something
11921192
with warnings.catch_warnings():
11931193
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
1194-
trace = pm.sample(chains=1, draws=50)
1194+
trace = pm.sample(chains=1, tune=5, draws=50)
11951195

11961196
def test_interval_missing_observations(self):
11971197
with pm.Model() as model:

0 commit comments

Comments
 (0)