Skip to content

Commit 44d3730

Browse files
committed
Reinstate test_normal_scalar_idata
1 parent ef7a2e6 commit 44d3730

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pymc/tests/test_sampling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,29 @@ def test_normal_scalar(self):
505505
)
506506
assert ppc["a"].shape == (nchains * ndraws, 5)
507507

508+
def test_normal_scalar_idata(self):
509+
nchains = 2
510+
ndraws = 500
511+
with pm.Model() as model:
512+
mu = pm.Normal("mu", 0.0, 1.0)
513+
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
514+
trace = pm.sample(
515+
draws=ndraws,
516+
chains=nchains,
517+
return_inferencedata=False,
518+
discard_tuned_samples=False,
519+
)
520+
521+
assert not isinstance(trace, InferenceData)
522+
523+
with model:
524+
# test keep_size parameter and idata input
525+
idata = pm.to_inference_data(trace)
526+
assert isinstance(idata, InferenceData)
527+
528+
ppc = pm.sample_posterior_predictive(idata, keep_size=True, return_inferencedata=False)
529+
assert ppc["a"].shape == (nchains, ndraws)
530+
508531
def test_normal_vector(self, caplog):
509532
with pm.Model() as model:
510533
mu = pm.Normal("mu", 0.0, 1.0)

0 commit comments

Comments
 (0)