File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -825,7 +825,7 @@ def sample_posterior_predictive(
825
825
if return_inferencedata and not extend_inferencedata :
826
826
return InferenceData ()
827
827
elif return_inferencedata and extend_inferencedata :
828
- return trace
828
+ return trace if idata is None else idata
829
829
return {}
830
830
831
831
vars_in_trace = get_vars_in_point_list (_trace , model )
Original file line number Diff line number Diff line change @@ -492,6 +492,11 @@ def test_normal_scalar(self):
492
492
ppc = pm .sample_posterior_predictive (trace , var_names = [], return_inferencedata = False )
493
493
assert len (ppc ) == 0
494
494
495
+ # test empty ppc with extend_inferencedata
496
+ assert isinstance (trace , InferenceData )
497
+ ppc = pm .sample_posterior_predictive (trace , var_names = [], extend_inferencedata = True )
498
+ assert ppc is trace
499
+
495
500
# test keep_size parameter
496
501
ppc = pm .sample_posterior_predictive (trace , return_inferencedata = False )
497
502
assert ppc ["a" ].shape == (nchains , ndraws )
You can’t perform that action at this time.
0 commit comments