Skip to content

Commit 826c602

Browse files
committed
Fix tests idata conversion
1 parent fbda63e commit 826c602

File tree

1 file changed

+28
-33
lines changed

1 file changed

+28
-33
lines changed

pymc/tests/test_idata_conversion.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ def data(self, eight_schools_params, draws, chains):
6666

6767
def get_inference_data(self, data, eight_schools_params):
6868
with data.model:
69-
prior = pm.sample_prior_predictive()
70-
posterior_predictive = pm.sample_posterior_predictive(data.obj)
69+
prior = pm.sample_prior_predictive(return_inferencedata=False)
70+
posterior_predictive = pm.sample_posterior_predictive(
71+
data.obj, return_inferencedata=False
72+
)
7173

7274
return (
7375
to_inference_data(
@@ -85,8 +87,10 @@ def get_predictions_inference_data(
8587
self, data, eight_schools_params, inplace
8688
) -> Tuple[InferenceData, Dict[str, np.ndarray]]:
8789
with data.model:
88-
prior = pm.sample_prior_predictive()
89-
posterior_predictive = pm.sample_posterior_predictive(data.obj)
90+
prior = pm.sample_prior_predictive(return_inferencedata=False)
91+
posterior_predictive = pm.sample_posterior_predictive(
92+
data.obj, return_inferencedata=False
93+
)
9094

9195
idata = to_inference_data(
9296
trace=data.obj,
@@ -106,7 +110,9 @@ def make_predictions_inference_data(
106110
self, data, eight_schools_params
107111
) -> Tuple[InferenceData, Dict[str, np.ndarray]]:
108112
with data.model:
109-
posterior_predictive = pm.sample_posterior_predictive(data.obj)
113+
posterior_predictive = pm.sample_posterior_predictive(
114+
data.obj, return_inferencedata=False
115+
)
110116
idata = predictions_to_inference_data(
111117
posterior_predictive,
112118
posterior_trace=data.obj,
@@ -199,7 +205,9 @@ def test_predictions_to_idata_new(self, data, eight_schools_params):
199205

200206
def test_posterior_predictive_keep_size(self, data, chains, draws, eight_schools_params):
201207
with data.model:
202-
posterior_predictive = pm.sample_posterior_predictive(data.obj, keep_size=True)
208+
posterior_predictive = pm.sample_posterior_predictive(
209+
data.obj, keep_size=True, return_inferencedata=False
210+
)
203211
inference_data = to_inference_data(
204212
trace=data.obj,
205213
posterior_predictive=posterior_predictive,
@@ -214,7 +222,9 @@ def test_posterior_predictive_keep_size(self, data, chains, draws, eight_schools
214222

215223
def test_posterior_predictive_warning(self, data, eight_schools_params, caplog):
216224
with data.model:
217-
posterior_predictive = pm.sample_posterior_predictive(data.obj, 370)
225+
posterior_predictive = pm.sample_posterior_predictive(
226+
data.obj, 370, return_inferencedata=False
227+
)
218228
inference_data = to_inference_data(
219229
trace=data.obj,
220230
posterior_predictive=posterior_predictive,
@@ -375,10 +385,7 @@ def test_multiple_observed_rv_without_observations(self):
375385
with pm.Model():
376386
mu = pm.Normal("mu")
377387
x = pm.DensityDist( # pylint: disable=unused-variable
378-
"x",
379-
mu,
380-
logp=lambda value, mu: pm.Normal.logp(value, mu, 1),
381-
observed=0.1,
388+
"x", mu, logp=lambda value, mu: pm.Normal.logp(value, mu, 1), observed=0.1
382389
)
383390
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
384391
test_dict = {
@@ -483,7 +490,9 @@ def test_predictions_constant_data(self):
483490
y = pm.Data("y", [1.0, 2.0])
484491
beta = pm.Normal("beta", 0, 1)
485492
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
486-
predictive_trace = pm.sample_posterior_predictive(inference_data)
493+
predictive_trace = pm.sample_posterior_predictive(
494+
inference_data, return_inferencedata=False
495+
)
487496
assert set(predictive_trace.keys()) == {"obs"}
488497
# this should be four chains of 100 samples
489498
# assert predictive_trace["obs"].shape == (400, 2)
@@ -506,8 +515,8 @@ def test_no_trace(self):
506515
beta = pm.Normal("beta", 0, 1)
507516
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
508517
idata = pm.sample(100, tune=100)
509-
prior = pm.sample_prior_predictive()
510-
posterior_predictive = pm.sample_posterior_predictive(idata)
518+
prior = pm.sample_prior_predictive(return_inferencedata=False)
519+
posterior_predictive = pm.sample_posterior_predictive(idata, return_inferencedata=False)
511520

512521
# Only prior
513522
inference_data = to_inference_data(prior=prior, model=model)
@@ -539,7 +548,7 @@ def test_priors_separation(self, use_context):
539548
y = pm.Data("y", [1.0, 2.0, 3.0])
540549
beta = pm.Normal("beta", 0, 1)
541550
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
542-
prior = pm.sample_prior_predictive()
551+
prior = pm.sample_prior_predictive(return_inferencedata=False)
543552

544553
test_dict = {
545554
"prior": ["beta", "~obs"],
@@ -574,10 +583,7 @@ def test_multivariate_observations(self):
574583

575584
def test_constant_data_coords_issue_5046(self):
576585
"""This is a regression test against a bug where a local coords variable was overwritten."""
577-
dims = {
578-
"alpha": ["backwards"],
579-
"bravo": ["letters", "yesno"],
580-
}
586+
dims = {"alpha": ["backwards"], "bravo": ["letters", "yesno"]}
581587
coords = {
582588
"backwards": np.arange(17)[::-1],
583589
"letters": list("ABCDEFGHIJK"),
@@ -592,20 +598,13 @@ def test_constant_data_coords_issue_5046(self):
592598
assert len(data[k].shape) == len(dims[k])
593599

594600
ds = pm.backends.arviz.dict_to_dataset(
595-
data=data,
596-
library=pm,
597-
coords=coords,
598-
dims=dims,
599-
default_dims=[],
600-
index_origin=0,
601+
data=data, library=pm, coords=coords, dims=dims, default_dims=[], index_origin=0
601602
)
602603
for dname, cvals in coords.items():
603604
np.testing.assert_array_equal(ds[dname].values, cvals)
604605

605606
def test_issue_5043_autoconvert_coord_values(self):
606-
coords = {
607-
"city": pd.Series(["Bonn", "Berlin"]),
608-
}
607+
coords = {"city": pd.Series(["Bonn", "Berlin"])}
609608
with pm.Model(coords=coords) as pmodel:
610609
# The model tracks coord values as (immutable) tuples
611610
assert isinstance(pmodel.coords["city"], tuple)
@@ -631,11 +630,7 @@ def test_issue_5043_autoconvert_coord_values(self):
631630
trace=mtrace,
632631
coords={
633632
"city": pd.MultiIndex.from_tuples(
634-
[
635-
("Bonn", 53111),
636-
("Berlin", 10178),
637-
],
638-
names=["name", "zipcode"],
633+
[("Bonn", 53111), ("Berlin", 10178)], names=["name", "zipcode"]
639634
)
640635
},
641636
)

0 commit comments

Comments
 (0)