Skip to content

Commit 7f6abd4

Browse files
added test
1 parent b93588f commit 7f6abd4

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

pymc/tests/test_sampling.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,28 @@ def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]
10421042

10431043

10441044
class TestSamplePriorPredictive(SeededTest):
1045+
def test_idata_output(self):
1046+
"""This test controls that returned idata
1047+
contains all expected groups"""
1048+
1049+
with pm.Model() as model:
1050+
x = pm.MutableData("x", [1, 2, 3])
1051+
y = pm.MutableData("y", [1.1, 1.9, 3.1])
1052+
a = pm.Normal("a", mu=1, sigma=1)
1053+
b = pm.Normal("b", mu=0, sigma=1)
1054+
mu = pm.Deterministic("mu", var=a * x + b)
1055+
obs = pm.Normal("obs", mu=mu, sigma=1, observed=y)
1056+
idata = pm.sample_prior_predictive(samples=10)
1057+
1058+
test_dict = {
1059+
"prior": ["a", "b", "mu"],
1060+
"prior_predictive": ["obs"],
1061+
"observed_data": ["obs"],
1062+
"constant_data": ["x", "y"],
1063+
}
1064+
fails = check_multiple_attrs(test_dict, idata)
1065+
assert not fails
1066+
10451067
def test_ignores_observed(self):
10461068
observed = np.random.normal(10, 1, size=200)
10471069
with pm.Model():

0 commit comments

Comments
 (0)