Skip to content

Commit 53e0774

Browse files
committed
Fix bug with chained CustomSymbolicDists
1 parent 710c38b commit 53e0774

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

pymc/distributions/distribution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pytensor import tensor as pt
2929
from pytensor.compile.builders import OpFromGraph
30-
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
30+
from pytensor.graph import FunctionGraph, graph_replace, node_rewriter
3131
from pytensor.graph.basic import Apply, Variable
3232
from pytensor.graph.rewriting.basic import in2out
3333
from pytensor.graph.utils import MetaType
@@ -588,7 +588,9 @@ def inline_symbolic_random_variable(fgraph, node):
588588
"""Expand a SymbolicRV when obtaining the logp graph if `inline_logprob` is True."""
589589
op = node.op
590590
if op.inline_logprob:
591-
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
591+
return graph_replace(
592+
op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)), strict=False
593+
)
592594

593595

594596
# Registered before pre-canonicalization which happens at position=-10

tests/distributions/test_custom.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from numpy import random as npr
2222
from pytensor import scan
2323
from pytensor import tensor as pt
24+
from pytensor.graph import FunctionGraph
2425
from scipy import stats as st
2526

2627
from pymc.distributions import (
@@ -42,11 +43,11 @@
4243
Uniform,
4344
)
4445
from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
45-
from pymc.distributions.distribution import support_point
46+
from pymc.distributions.distribution import inline_symbolic_random_variable, support_point
4647
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
4748
from pymc.distributions.transforms import log
4849
from pymc.exceptions import BlockModelAccessError
49-
from pymc.logprob import logcdf, logp
50+
from pymc.logprob import conditional_logp, logcdf, logp
5051
from pymc.model import Deterministic, Model
5152
from pymc.pytensorf import collect_default_updates
5253
from pymc.sampling import draw, sample, sample_posterior_predictive
@@ -648,3 +649,71 @@ def dist(p, size):
648649
assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
649650
assert out.owner.op.ndim_supp == 0
650651
assert out.owner.op.ndims_params == [0]
652+
653+
def test_inline_does_not_duplicate_graph(self):
654+
mu = Normal.dist()
655+
x = CustomDist.dist(mu, dist=lambda mu, size: Normal.dist(mu, size=size))
656+
657+
fgraph = FunctionGraph(outputs=[x], clone=False)
658+
[inner_x, inner_rng_update] = inline_symbolic_random_variable.transform(fgraph, x.owner)
659+
assert inner_rng_update.owner.inputs[-2] is mu
660+
assert inner_x.owner.inputs[-2] is mu
661+
662+
def test_chained_custom_dist_bug(self):
663+
"""Regression test for issue reported in https://discourse.pymc.io/t/error-with-custom-distribution-after-using-scan/16255
664+
665+
This bug was caused by a duplication of a Scan-based CustomSymbolicDist when inlining another CustomSymbolicDist that used it as an input.
666+
PyTensor failed to merge the two Scan graphs, causing a failure in the logp extraction.
667+
"""
668+
669+
rng = np.random.default_rng(123)
670+
steps = 4
671+
batch = 2
672+
673+
def scan_dist(seq, n_steps, size):
674+
def step(s):
675+
innov = Normal.dist()
676+
traffic = s + innov
677+
return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
678+
679+
rv_seq, _ = pytensor.scan(
680+
fn=step,
681+
sequences=[seq],
682+
outputs_info=[None],
683+
n_steps=n_steps,
684+
strict=True,
685+
)
686+
return rv_seq
687+
688+
def normal_shifted(mu, size):
689+
return Normal.dist(mu=mu, size=size) - 1
690+
691+
seq = pt.matrix("seq", shape=(batch, steps))
692+
latent_rv = CustomDist.dist(
693+
seq.T,
694+
steps,
695+
dist=scan_dist,
696+
shape=(steps, batch),
697+
)
698+
latent_rv.name = "latent"
699+
700+
observed_rv = CustomDist.dist(
701+
latent_rv,
702+
dist=normal_shifted,
703+
shape=(steps, batch),
704+
)
705+
observed_rv.name = "observed"
706+
707+
latent_vv = latent_rv.type()
708+
observed_vv = observed_rv.type()
709+
710+
observed_logp = conditional_logp({latent_rv: latent_vv, observed_rv: observed_vv})[
711+
observed_vv
712+
]
713+
latent_vv_test = rng.standard_normal(size=(steps, batch))
714+
observed_vv_test = rng.standard_normal(size=(steps, batch))
715+
expected_logp = st.norm.logpdf(observed_vv_test + 1, loc=latent_vv_test)
716+
np.testing.assert_allclose(
717+
observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
718+
expected_logp,
719+
)

0 commit comments

Comments
 (0)