Skip to content

Commit 44dc340

Browse files
committed
Fix bug in nested replacement of value vars
1 parent eb1d8d6 commit 44dc340

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

pymc/aesaraf.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def expand_replace(var):
236236
new_nodes.extend(replacement_fn(var, replacements))
237237
return new_nodes
238238

239+
# This iteration populates the replacements
239240
for var in walk_model(graphs, expand_fn=expand_replace, **kwargs):
240241
pass
241242

@@ -250,7 +251,15 @@ def expand_replace(var):
250251
clone=False,
251252
)
252253

253-
fg.replace_all(replacements.items(), import_missing=True)
254+
# replacements have to be done in reverse topological order so that nested
255+
# expressions get recursively replaced correctly
256+
toposort = fg.toposort()
257+
sorted_replacements = sorted(
258+
tuple(replacements.items()),
259+
key=lambda pair: toposort.index(pair[0].owner),
260+
reverse=True,
261+
)
262+
fg.replace_all(sorted_replacements, import_missing=True)
254263

255264
graphs = list(fg.outputs)
256265

pymc/tests/distributions/test_logprob.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,7 @@
5252
logp,
5353
)
5454
from pymc.model import Model, Potential
55-
from pymc.tests.helpers import select_by_precision
56-
57-
58-
def assert_no_rvs(var):
59-
assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
60-
return var
55+
from pymc.tests.helpers import assert_no_rvs, select_by_precision
6156

6257

6358
def test_get_scaling():

pymc/tests/helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import numpy.random as nr
2525

2626
from aesara.gradient import verify_grad as at_verify_grad
27+
from aesara.graph import ancestors
2728
from aesara.graph.rewriting.basic import in2out
2829
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
30+
from aesara.tensor.random.op import RandomVariable
2931

3032
import pymc as pm
3133

@@ -218,3 +220,8 @@ def continuous_steps(self, step, step_kwargs):
218220
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
219221
step([c1, c2], **step_kwargs).vars
220222
)
223+
224+
225+
def assert_no_rvs(var):
226+
assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
227+
return var

pymc/tests/test_aesaraf.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from pymc.distributions.distribution import SymbolicRandomVariable
4848
from pymc.distributions.transforms import Interval
4949
from pymc.exceptions import NotConstantValueError
50+
from pymc.tests.helpers import assert_no_rvs
5051
from pymc.vartypes import int_types
5152

5253

@@ -632,3 +633,56 @@ def test_no_change_inplace(self):
632633
after = aesara.clone_replace(m.free_RVs)
633634

634635
assert equal_computations(before, after)
636+
637+
@pytest.mark.parametrize("reversed", (False, True))
638+
def test_interdependent_transformed_rvs(self, reversed):
639+
# Test that nested transformed variables, whose transformed values depend on other
640+
# RVs are properly replaced
641+
with pm.Model() as m:
642+
transform = pm.distributions.transforms.Interval(
643+
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
644+
)
645+
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
646+
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
647+
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
648+
w = pm.Uniform("w", lower=0, upper=z, transform=transform)
649+
650+
rvs = [x, y, z, w]
651+
if reversed:
652+
rvs = rvs[::-1]
653+
654+
transform_values = rvs_to_value_vars(rvs)
655+
656+
for transform_value in transform_values:
657+
assert_no_rvs(transform_value)
658+
659+
if reversed:
660+
transform_values = transform_values[::-1]
661+
transform_values_fn = m.compile_fn(transform_values, point_fn=False)
662+
663+
x_interval_test_value = np.random.rand()
664+
y_interval_test_value = np.random.rand()
665+
z_interval_test_value = np.random.rand()
666+
w_interval_test_value = np.random.rand()
667+
668+
# The 3 Nones correspond to unused rng, dtype and size arguments
669+
expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
670+
expected_y = transform.backward(
671+
y_interval_test_value, None, None, None, 0, expected_x
672+
).eval()
673+
expected_z = transform.backward(
674+
z_interval_test_value, None, None, None, 0, expected_y
675+
).eval()
676+
expected_w = transform.backward(
677+
w_interval_test_value, None, None, None, 0, expected_z
678+
).eval()
679+
680+
np.testing.assert_allclose(
681+
transform_values_fn(
682+
x_interval__=x_interval_test_value,
683+
y_interval__=y_interval_test_value,
684+
z_interval__=z_interval_test_value,
685+
w_interval__=w_interval_test_value,
686+
),
687+
[expected_x, expected_y, expected_z, expected_w],
688+
)

0 commit comments

Comments
 (0)