Skip to content

Commit 74b7788

Browse files
authored
Merge pull request #3841 from rpgoldman/iss3840
Fix computation of samples argument in sample_posterior_predictive Solves #3840
2 parents 363afc8 + 839206b commit 74b7788

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pymc3/sampling.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,13 @@ def sample_posterior_predictive(
15681568
raise IncorrectArgumentsError("Should not specify both keep_size and size argukments")
15691569

15701570
if samples is None:
1571-
samples = sum(len(v) for v in trace._straces.values())
1571+
if isinstance(trace, MultiTrace):
1572+
samples = sum(len(v) for v in trace._straces.values())
1573+
elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
1574+
# this is a list of points
1575+
samples = len(trace)
1576+
else:
1577+
raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace))
15721578

15731579
if samples < len_trace * nchain:
15741580
warnings.warn(

pymc3/tests/test_sampling.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from itertools import combinations
16+
from typing import Tuple
1617
import numpy as np
1718

1819
try:
@@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
693694
assert "a" in start[0] and "b_log__" in start[0]
694695

695696

697+
@pytest.fixture(scope="class")
698+
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
699+
with pm.Model() as pmodel:
700+
n = pm.Normal('n')
701+
trace = pm.sample()
702+
703+
with pmodel:
704+
d = pm.Deterministic('d', n * 4)
705+
return pmodel, trace
706+
696707
class TestSamplePriorPredictive(SeededTest):
697708
def test_ignores_observed(self):
698709
observed = np.random.normal(10, 1, size=200)
@@ -851,3 +862,21 @@ def test_bounded_dist(self):
851862
with model:
852863
prior_trace = pm.sample_prior_predictive(5)
853864
assert prior_trace["x"].shape == (5, 3, 1)
865+
866+
class TestSamplePosteriorPredictive:
867+
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
868+
pmodel, trace = point_list_arg_bug_fixture
869+
with pmodel:
870+
pp = pm.fast_sample_posterior_predictive(
871+
[trace[15]],
872+
var_names=['d']
873+
)
874+
875+
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
876+
pmodel, trace = point_list_arg_bug_fixture
877+
with pmodel:
878+
pp = pm.sample_posterior_predictive(
879+
[trace[15]],
880+
var_names=['d']
881+
)
882+

0 commit comments

Comments
 (0)