Skip to content

Commit 7840581

Browse files
committed
Fix rvs_to_value_vars inplace update bug
1 parent 44cf8a7 commit 7840581

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

pymc/aesaraf.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def rvs_to_value_vars(
335335
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
336336
**kwargs,
337337
) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
338-
"""Replace random variables in graphs with their value variables.
338+
"""Clone and replace random variables in graphs with their value variables.
339339
340340
This will *not* recompute test values in the resulting graphs.
341341
@@ -383,6 +383,16 @@ def transform_replacements(var, replacements):
383383
# Walk the transformed variable and make replacements
384384
return [trans_rv_value]
385385

386+
# Clone original graphs
387+
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
388+
equiv = clone_get_equiv(inputs, graphs, False, False, {})
389+
graphs = [equiv[n] for n in graphs]
390+
391+
if initial_replacements:
392+
initial_replacements = {
393+
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
394+
}
395+
386396
return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
387397

388398

pymc/tests/test_aesaraf.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pytest
2424
import scipy.sparse as sps
2525

26-
from aesara.graph.basic import Constant, Variable, ancestors
26+
from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
2727
from aesara.tensor.random.basic import normal, uniform
2828
from aesara.tensor.random.op import RandomVariable
2929
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -529,3 +529,24 @@ def test_rvs_to_value_vars():
529529
assert a_value_var in res_ancestors
530530
assert b_value_var in res_ancestors
531531
assert c_value_var in res_ancestors
532+
533+
534+
def test_rvs_to_value_vars_nested():
535+
# Test that calling rvs_to_value_vars in models with nested transformations
536+
# does not change the original rvs in place. See issue #5172
537+
with pm.Model() as m:
538+
one = pm.LogNormal("one", mu=0)
539+
two = pm.LogNormal("two", mu=at.log(one))
540+
541+
# We add potentials or deterministics that are not in topological order
542+
pm.Potential("two_pot", two)
543+
pm.Potential("one_pot", one)
544+
545+
before = aesara.clone_replace(m.free_RVs)
546+
547+
# This call would change the model free_RVs in place in #5172
548+
res, _ = rvs_to_value_vars(m.potentials, apply_transforms=True)
549+
550+
after = aesara.clone_replace(m.free_RVs)
551+
552+
assert equal_computations(before, after)

0 commit comments

Comments
 (0)