Skip to content

Commit 791c53c

Browse files
authored
Use to_tuple function in pm.fast_sample_posterior_predictive (#4927)
Closes #4854
1 parent 10c914b commit 791c53c

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
+ Fix `LKJCorr.random` method to work with `pm.sample_prior_predictive`. (see [#4780](https://github.com/pymc-devs/pymc3/pull/4780)).
99
+ Enable documentation generation via ReadTheDocs for upcoming v3 releases. (see [#4805](https://github.com/pymc-devs/pymc3/pull/4805)).
1010
+ Remove `float128` dtype support (see [#4834](https://github.com/pymc-devs/pymc3/pull/4834)).
11+
+ Use `to_tuple` function in `pm.fast_sample_posterior_predictive` to pass shape assertions (see [#4927](https://github.com/pymc-devs/pymc3/pull/4927)).
1112

1213
### New Features
1314
+ Generalized BART, bounded distributions like Binomial and Poisson can now be used as likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675), [#4709](https://github.com/pymc-devs/pymc3/pull/4709) and

pymc3/distributions/posterior_predictive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_fast_drawable,
2727
vectorized_ppc,
2828
)
29+
from pymc3.distributions.shape_utils import to_tuple
2930
from pymc3.exceptions import IncorrectArgumentsError
3031
from pymc3.model import (
3132
Model,
@@ -551,7 +552,7 @@ def random_sample(
551552
) -> np.ndarray:
552553
val = meth(point=point, size=size)
553554
try:
554-
assert val.shape == (size,) + shape, (
555+
assert val.shape == to_tuple(size) + to_tuple(shape), (
555556
"Sampling from random of %s yields wrong shape" % param
556557
)
557558
# error-quashing here is *extremely* ugly, but it seems to be what the logic in DensityDist wants.

pymc3/tests/test_posterior_predictive.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,26 @@ def test_build_TraceDict_point_list():
3737
assert len(dict) == 1
3838
assert len(dict["mu"]) == 1
3939
assert dict["mu"][0] == 0.0
40+
41+
42+
def test_fast_sample_posterior_predictive_shape_assertions():
43+
"""
44+
This test checks the shape assertions in pm.fast_sample_posterior_predictive.
45+
Originally reported - https://github.com/pymc-devs/pymc3/issues/4778
46+
"""
47+
with pm.Model():
48+
p = pm.Beta("p", 2, 2)
49+
trace = pm.sample(
50+
tune=30, draws=50, chains=1, return_inferencedata=True, compute_convergence_checks=False
51+
)
52+
53+
with pm.Model() as m_forward:
54+
p = pm.Beta("p", 2, 2)
55+
b2 = pm.Binomial("b2", n=1, p=p)
56+
b3 = pm.Binomial("b3", n=1, p=p * b2)
57+
58+
with m_forward:
59+
trace_forward = pm.fast_sample_posterior_predictive(trace, var_names=["p", "b2", "b3"])
60+
61+
for free_rv in trace_forward.values():
62+
assert free_rv.shape[0] == 50

0 commit comments

Comments
 (0)