|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | from itertools import combinations
|
| 16 | +from typing import Tuple |
16 | 17 | import numpy as np
|
17 | 18 |
|
18 | 19 | try:
|
@@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
|
693 | 694 | assert "a" in start[0] and "b_log__" in start[0]
|
694 | 695 |
|
695 | 696 |
|
| 697 | +@pytest.fixture(scope="class") |
| 698 | +def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: |
| 699 | + with pm.Model() as pmodel: |
| 700 | + n = pm.Normal('n') |
| 701 | + trace = pm.sample() |
| 702 | + |
| 703 | + with pmodel: |
| 704 | + d = pm.Deterministic('d', n * 4) |
| 705 | + return pmodel, trace |
| 706 | + |
696 | 707 | class TestSamplePriorPredictive(SeededTest):
|
697 | 708 | def test_ignores_observed(self):
|
698 | 709 | observed = np.random.normal(10, 1, size=200)
|
@@ -851,3 +862,21 @@ def test_bounded_dist(self):
|
851 | 862 | with model:
|
852 | 863 | prior_trace = pm.sample_prior_predictive(5)
|
853 | 864 | assert prior_trace["x"].shape == (5, 3, 1)
|
| 865 | + |
| 866 | +class TestSamplePosteriorPredictive: |
| 867 | + def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture): |
| 868 | + pmodel, trace = point_list_arg_bug_fixture |
| 869 | + with pmodel: |
| 870 | + pp = pm.fast_sample_posterior_predictive( |
| 871 | + [trace[15]], |
| 872 | + var_names=['d'] |
| 873 | + ) |
| 874 | + |
| 875 | + def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture): |
| 876 | + pmodel, trace = point_list_arg_bug_fixture |
| 877 | + with pmodel: |
| 878 | + pp = pm.sample_posterior_predictive( |
| 879 | + [trace[15]], |
| 880 | + var_names=['d'] |
| 881 | + ) |
| 882 | + |
0 commit comments