Skip to content

Commit b566060

Browse files
committed
Test to replicate issue 3840.
1 parent 363afc8 commit b566060

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

pymc3/tests/test_sampling.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from itertools import combinations
16+
from typing import Tuple
1617
import numpy as np
1718

1819
try:
@@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
693694
assert "a" in start[0] and "b_log__" in start[0]
694695

695696

697+
@pytest.fixture(scope="class")
698+
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
699+
with pm.Model() as pmodel:
700+
n = pm.Normal('n')
701+
trace = pm.sample()
702+
703+
with pmodel:
704+
d = pm.Deterministic('d', n * 4)
705+
return pmodel, trace
706+
696707
class TestSamplePriorPredictive(SeededTest):
697708
def test_ignores_observed(self):
698709
observed = np.random.normal(10, 1, size=200)
@@ -851,3 +862,21 @@ def test_bounded_dist(self):
851862
with model:
852863
prior_trace = pm.sample_prior_predictive(5)
853864
assert prior_trace["x"].shape == (5, 3, 1)
865+
866+
class TestSamplePosteriorPredictive:
867+
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
868+
pmodel, trace = point_list_arg_bug_fixture
869+
with pmodel:
870+
pp = pm.fast_sample_posterior_predictive(
871+
[trace[15]],
872+
var_names=['d']
873+
)
874+
875+
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
876+
pmodel, trace = point_list_arg_bug_fixture
877+
with pmodel:
878+
pp = pm.sample_posterior_predictive(
879+
[trace[15]],
880+
var_names=['d']
881+
)
882+

0 commit comments

Comments
 (0)