@@ -1042,6 +1042,28 @@ def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]
1042
1042
1043
1043
1044
1044
class TestSamplePriorPredictive (SeededTest ):
1045
+ def test_idata_output (self ):
1046
+ """This test controls that returned idata
1047
+ contains all expected groups"""
1048
+
1049
+ with pm .Model () as model :
1050
+ x = pm .MutableData ("x" , [1 , 2 , 3 ])
1051
+ y = pm .MutableData ("y" , [1.1 , 1.9 , 3.1 ])
1052
+ a = pm .Normal ("a" , mu = 1 , sigma = 1 )
1053
+ b = pm .Normal ("b" , mu = 0 , sigma = 1 )
1054
+ mu = pm .Deterministic ("mu" , var = a * x + b )
1055
+ obs = pm .Normal ("obs" , mu = mu , sigma = 1 , observed = y )
1056
+ idata = pm .sample_prior_predictive (samples = 10 )
1057
+
1058
+ test_dict = {
1059
+ "prior" : ["a" , "b" , "mu" ],
1060
+ "prior_predictive" : ["obs" ],
1061
+ "observed_data" : ["obs" ],
1062
+ "constant_data" : ["x" , "y" ],
1063
+ }
1064
+ fails = check_multiple_attrs (test_dict , idata )
1065
+ assert not fails
1066
+
1045
1067
def test_ignores_observed (self ):
1046
1068
observed = np .random .normal (10 , 1 , size = 200 )
1047
1069
with pm .Model ():
0 commit comments