@@ -1154,8 +1154,8 @@ def test_missing_basic(self, missing_data):
1154
1154
assert not np .isnan (model .compile_logp ()(test_point ))
1155
1155
1156
1156
with model :
1157
- prior_trace = pm .sample_prior_predictive (return_inferencedata = False )
1158
- assert {"x" , "y" } <= set (prior_trace .keys ())
1157
+ ipr = pm .sample_prior_predictive ()
1158
+ assert {"x" , "y" } <= set (ipr . prior .keys ())
1159
1159
1160
1160
def test_missing_with_predictors (self ):
1161
1161
predictors = np .array ([0.5 , 1 , 0.5 , 2 , 0.3 ])
@@ -1171,8 +1171,8 @@ def test_missing_with_predictors(self):
1171
1171
assert not np .isnan (model .compile_logp ()(test_point ))
1172
1172
1173
1173
with model :
1174
- prior_trace = pm .sample_prior_predictive (return_inferencedata = False )
1175
- assert {"x" , "y" } <= set (prior_trace .keys ())
1174
+ ipr = pm .sample_prior_predictive ()
1175
+ assert {"x" , "y" } <= set (ipr . prior .keys ())
1176
1176
1177
1177
def test_missing_dual_observations (self ):
1178
1178
with pm .Model () as model :
@@ -1191,7 +1191,7 @@ def test_missing_dual_observations(self):
1191
1191
# TODO: Assert something
1192
1192
with warnings .catch_warnings ():
1193
1193
warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
1194
- trace = pm .sample (chains = 1 , draws = 50 )
1194
+ trace = pm .sample (chains = 1 , tune = 5 , draws = 50 )
1195
1195
1196
1196
def test_interval_missing_observations (self ):
1197
1197
with pm .Model () as model :
0 commit comments