|
21 | 21 | from numpy import random as npr
|
22 | 22 | from pytensor import scan
|
23 | 23 | from pytensor import tensor as pt
|
| 24 | +from pytensor.graph import FunctionGraph |
24 | 25 | from scipy import stats as st
|
25 | 26 |
|
26 | 27 | from pymc.distributions import (
|
|
42 | 43 | Uniform,
|
43 | 44 | )
|
44 | 45 | from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
|
45 |
| -from pymc.distributions.distribution import support_point |
| 46 | +from pymc.distributions.distribution import inline_symbolic_random_variable, support_point |
46 | 47 | from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
|
47 | 48 | from pymc.distributions.transforms import log
|
48 | 49 | from pymc.exceptions import BlockModelAccessError
|
49 |
| -from pymc.logprob import logcdf, logp |
| 50 | +from pymc.logprob import conditional_logp, logcdf, logp |
50 | 51 | from pymc.model import Deterministic, Model
|
51 | 52 | from pymc.pytensorf import collect_default_updates
|
52 | 53 | from pymc.sampling import draw, sample, sample_posterior_predictive
|
@@ -648,3 +649,71 @@ def dist(p, size):
|
648 | 649 | assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
|
649 | 650 | assert out.owner.op.ndim_supp == 0
|
650 | 651 | assert out.owner.op.ndims_params == [0]
|
| 652 | + |
| 653 | + def test_inline_does_not_duplicate_graph(self): |
| 654 | + mu = Normal.dist() |
| 655 | + x = CustomDist.dist(mu, dist=lambda mu, size: Normal.dist(mu, size=size)) |
| 656 | + |
| 657 | + fgraph = FunctionGraph(outputs=[x], clone=False) |
| 658 | + [inner_x, inner_rng_update] = inline_symbolic_random_variable.transform(fgraph, x.owner) |
| 659 | + assert inner_rng_update.owner.inputs[-2] is mu |
| 660 | + assert inner_x.owner.inputs[-2] is mu |
| 661 | + |
| 662 | + def test_chained_custom_dist_bug(self): |
| 663 | + """Regression test for issue reported in https://discourse.pymc.io/t/error-with-custom-distribution-after-using-scan/16255 |
| 664 | +
|
| 665 | + This bug was caused by a duplication of a Scan-based CustomSymbolicDist when inlining another CustomSymbolicDist that used it as an input. |
| 666 | + PyTensor failed to merge the two Scan graphs, causing a failure in the logp extraction. |
| 667 | + """ |
| 668 | + |
| 669 | + rng = np.random.default_rng(123) |
| 670 | + steps = 4 |
| 671 | + batch = 2 |
| 672 | + |
| 673 | + def scan_dist(seq, n_steps, size): |
| 674 | + def step(s): |
| 675 | + innov = Normal.dist() |
| 676 | + traffic = s + innov |
| 677 | + return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]} |
| 678 | + |
| 679 | + rv_seq, _ = pytensor.scan( |
| 680 | + fn=step, |
| 681 | + sequences=[seq], |
| 682 | + outputs_info=[None], |
| 683 | + n_steps=n_steps, |
| 684 | + strict=True, |
| 685 | + ) |
| 686 | + return rv_seq |
| 687 | + |
| 688 | + def normal_shifted(mu, size): |
| 689 | + return Normal.dist(mu=mu, size=size) - 1 |
| 690 | + |
| 691 | + seq = pt.matrix("seq", shape=(batch, steps)) |
| 692 | + latent_rv = CustomDist.dist( |
| 693 | + seq.T, |
| 694 | + steps, |
| 695 | + dist=scan_dist, |
| 696 | + shape=(steps, batch), |
| 697 | + ) |
| 698 | + latent_rv.name = "latent" |
| 699 | + |
| 700 | + observed_rv = CustomDist.dist( |
| 701 | + latent_rv, |
| 702 | + dist=normal_shifted, |
| 703 | + shape=(steps, batch), |
| 704 | + ) |
| 705 | + observed_rv.name = "observed" |
| 706 | + |
| 707 | + latent_vv = latent_rv.type() |
| 708 | + observed_vv = observed_rv.type() |
| 709 | + |
| 710 | + observed_logp = conditional_logp({latent_rv: latent_vv, observed_rv: observed_vv})[ |
| 711 | + observed_vv |
| 712 | + ] |
| 713 | + latent_vv_test = rng.standard_normal(size=(steps, batch)) |
| 714 | + observed_vv_test = rng.standard_normal(size=(steps, batch)) |
| 715 | + expected_logp = st.norm.logpdf(observed_vv_test + 1, loc=latent_vv_test) |
| 716 | + np.testing.assert_allclose( |
| 717 | + observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}), |
| 718 | + expected_logp, |
| 719 | + ) |
0 commit comments