Skip to content

Commit b9fbfed

Browse files
OriolAbrilricardoV94
authored andcommitted
fix returned object when no vars to sample and extend=True
1 parent 4300be1 commit b9fbfed

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

pymc/sampling/forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def sample_posterior_predictive(
825825
if return_inferencedata and not extend_inferencedata:
826826
return InferenceData()
827827
elif return_inferencedata and extend_inferencedata:
828-
return trace
828+
return trace if idata is None else idata
829829
return {}
830830

831831
vars_in_trace = get_vars_in_point_list(_trace, model)

tests/sampling/test_forward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ def test_normal_scalar(self):
492492
ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False)
493493
assert len(ppc) == 0
494494

495+
# test empty ppc with extend_inferencedata
496+
assert isinstance(trace, InferenceData)
497+
ppc = pm.sample_posterior_predictive(trace, var_names=[], extend_inferencedata=True)
498+
assert ppc is trace
499+
495500
# test keep_size parameter
496501
ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False)
497502
assert ppc["a"].shape == (nchains, ndraws)

0 commit comments

Comments
 (0)