Skip to content

Commit 9547e7a

Browse files
committed
add test
1 parent 9e03e4d commit 9547e7a

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

pymc3/tests/test_sampling.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -718,22 +718,24 @@ def test_sample_posterior_predictive_w(self):
718718
mu = pm.Normal("mu", mu=0, sigma=1)
719719
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
720720
trace_0 = pm.sample()
721+
idata_0 = az.from_pymc3(trace_0)
721722

722723
with pm.Model() as model_1:
723724
mu = pm.Normal("mu", mu=0, sigma=1, shape=len(data0))
724725
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
725726
trace_1 = pm.sample()
726-
727-
traces = [trace_0, trace_0]
728-
models = [model_0, model_0]
729-
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
730-
assert ppc["y"].shape == (100, 500)
727+
idata_1 = az.from_pymc3(trace_1)
731728

732729
traces = [trace_0, trace_1]
730+
idatas = [idata_0, idata_1]
733731
models = [model_0, model_1]
732+
734733
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
735734
assert ppc["y"].shape == (100, 500)
736735

736+
ppc = pm.sample_posterior_predictive_w(idatas, 100, models)
737+
assert ppc["y"].shape == (100, 500)
738+
737739

738740
@pytest.mark.parametrize(
739741
"method",

0 commit comments

Comments
 (0)