Skip to content

Commit 2e05854

Browse files
committed
Avoid inplace mutation in replace_rvs_by_values
This would happen when transforms reference other variables
1 parent 01ddcb8 commit 2e05854

File tree

4 files changed

+85
-17
lines changed

4 files changed

+85
-17
lines changed

pymc/logprob/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from pytensor import Variable
4545
from pytensor import tensor as pt
4646
from pytensor.graph import Apply, Op, node_rewriter
47-
from pytensor.graph.basic import walk
47+
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk
4848
from pytensor.graph.op import HasInnerGraph
4949
from pytensor.link.c.type import CType
5050
from pytensor.raise_op import CheckAndRaise
@@ -77,6 +77,18 @@ def replace_rvs_by_values(
7777
Mapping between the original graph RVs and respective value transforms
7878
"""
7979

80+
if rvs_to_transforms:
81+
# Conditional transforms like Interval can reference variables in the original RV graph
82+
# To avoid mutating the original graphs in place, we have to clone them
83+
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
84+
equiv = clone_get_equiv(inputs, graphs, False, False)
85+
86+
graphs = [equiv[g] for g in graphs]
87+
rvs_to_values = {equiv.get(rv, rv): value for rv, value in rvs_to_values.items()}
88+
rvs_to_transforms = {
89+
equiv.get(rv, rv): transform for rv, transform in rvs_to_transforms.items()
90+
}
91+
8092
replacements = {}
8193

8294
def populate_replacements(var):

pymc/pytensorf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,10 @@ def replace_vars_in_graphs(
212212
) -> List[Variable]:
213213
"""Replace variables in graphs.
214214
215-
Graphs are cloned and not modified in place.
215+
Graphs are cloned and not modified in place, unless the replacement expressions include variables from the original graphs.
216+
216217
"""
217-
# Clone graph and get equivalences
218+
# Clone graphs and get equivalences
218219
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
219220
equiv = {k: k for k in replacements.keys()}
220221
equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
@@ -1064,7 +1065,7 @@ def as_symbolic_string(x, **kwargs):
10641065
def toposort_replace(
10651066
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
10661067
) -> None:
1067-
"""Replace multiple variables in topological order."""
1068+
"""Replace multiple variables in place in topological order."""
10681069
toposort = fgraph.toposort()
10691070
sorted_replacements = sorted(
10701071
replacements,

tests/logprob/test_utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
import pymc as pm
4848

49-
from pymc import SymbolicRandomVariable
49+
from pymc import SymbolicRandomVariable, inputvars
5050
from pymc.distributions.transforms import Interval
5151
from pymc.logprob.abstract import MeasurableVariable
5252
from pymc.logprob.basic import logp
@@ -210,7 +210,7 @@ def test_no_change_inplace(self):
210210
after = pytensor.clone_replace(m.free_RVs)
211211
assert equal_computations(before, after)
212212

213-
@pytest.mark.parametrize("reversed", (False, True))
213+
@pytest.mark.parametrize("reversed", (False,))
214214
def test_interdependent_transformed_rvs(self, reversed):
215215
# Test that nested transformed variables, whose transformed values depend on other
216216
# RVs are properly replaced
@@ -219,9 +219,10 @@ def test_interdependent_transformed_rvs(self, reversed):
219219
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
220220
)
221221
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
222-
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
222+
# Operation between the variables provides a regression test for #7054
223+
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
223224
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
224-
w = pm.Uniform("w", lower=0, upper=z, transform=transform)
225+
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)
225226

226227
rvs = [x, y, z, w]
227228
if reversed:
@@ -233,8 +234,9 @@ def test_interdependent_transformed_rvs(self, reversed):
233234
rvs_to_transforms=m.rvs_to_transforms,
234235
)
235236

236-
for transform_value in transform_values:
237-
assert_no_rvs(transform_value)
237+
assert_no_rvs(transform_values)
238+
# Test that we haven't introduced value variables in the random graph (issue #7054)
239+
assert not inputvars(rvs)
238240

239241
if reversed:
240242
transform_values = transform_values[::-1]
@@ -248,13 +250,13 @@ def test_interdependent_transformed_rvs(self, reversed):
248250
# The 3 Nones correspond to unused rng, dtype and size arguments
249251
expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
250252
expected_y = transform.backward(
251-
y_interval_test_value, None, None, None, 0, expected_x
253+
y_interval_test_value, None, None, None, 0, pt.exp(expected_x)
252254
).eval()
253255
expected_z = transform.backward(
254256
z_interval_test_value, None, None, None, 0, expected_y
255257
).eval()
256258
expected_w = transform.backward(
257-
w_interval_test_value, None, None, None, 0, expected_z
259+
w_interval_test_value, None, None, None, 0, pt.square(expected_z)
258260
).eval()
259261

260262
np.testing.assert_allclose(

tests/test_pytensorf.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import scipy.sparse as sps
2424

2525
from pytensor import scan, shared
26+
from pytensor.compile import UnusedInputError
2627
from pytensor.compile.builders import OpFromGraph
2728
from pytensor.graph.basic import Variable
2829
from pytensor.tensor.random.basic import normal, uniform
@@ -670,11 +671,63 @@ def test_replace_vars_in_graphs():
670671
inp = shared(0.0, name="inp")
671672
x = pm.Normal.dist(inp)
672673

673-
assert x.eval() < 50
674-
675-
new_inp = inp + 100
676-
677-
replacements = {x.owner.inputs[3]: new_inp}
674+
replacements = {inp: inp + 100}
678675
[new_x] = replace_vars_in_graphs([x], replacements=replacements)
679676

677+
assert x.eval() < 50
680678
assert new_x.eval() > 50
679+
680+
681+
def test_replace_vars_in_graphs_nested_reference():
682+
# Replace both `x` and `y`, where the replacement of y references `x`
683+
x = pm.HalfNormal.dist(1e-3, name="x")
684+
neg_x = -x
685+
y = pm.Uniform.dist(neg_x, x, name="y")
686+
x_value = x.clone()
687+
y_value = y.clone()
688+
replacements = {x: x_value, y: neg_x + y_value}
689+
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
690+
assert new_x.eval({x_value: 100}) == 100
691+
assert new_y.eval({x_value: 100, y_value: 1}) == -99
692+
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
693+
assert np.abs(x.eval()) < 1
694+
# Confirm the original `y` variable is changed in place
695+
# This is unavoidable if we want to respect the identity of the replacement variables
696+
# As when imputing `neg_x` and `x` while evaluating `new_y` above and below.
697+
assert np.abs(y.eval({x_value: 100})) > 1
698+
699+
# Only replace `y`, same replacement as before
700+
x = pm.HalfNormal.dist(1e-3, name="x")
701+
neg_x = -x
702+
y = pm.Uniform.dist(neg_x, x, name="y")
703+
y_value = y.clone()
704+
replacements = {y: neg_x + y_value}
705+
[new_y] = replace_vars_in_graphs([y], replacements=replacements)
706+
assert np.abs(new_y.eval({y_value: 0})) < 1
707+
# Confirm that `x` and `neg_x` are still in the graph of `new_y` and that we can impute either
708+
assert new_y.eval({x: 100, y_value: 1}) == -99
709+
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
710+
assert np.abs(x.eval()) < 1
711+
# In this case the original `y` is not altered, because we did not replace `x`
712+
assert np.abs(y.eval()) < 1
713+
714+
# Replacement introduces equivalent but not identical operations
715+
x = pm.HalfNormal.dist(1e-3, name="x")
716+
neg_x = -x
717+
neg_x.name = "neg_x"
718+
y = pm.Uniform.dist(neg_x, x, name="y")
719+
x_value = x.clone()
720+
y_value = y.clone()
721+
# We clone neg_x!
722+
replacements = {x: x_value, y: neg_x.owner.clone().outputs[0] + y_value}
723+
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
724+
assert new_x.eval({x_value: 100}) == 100
725+
assert new_y.eval({x_value: 100, y_value: 1}) == -99
726+
# This now fails because the original `neg_x` is not in the replaced graph!
727+
with pytest.raises(UnusedInputError, match="neg_x"):
728+
new_y.eval({neg_x: 100, y_value: 1})
729+
# We can retrieve the cloned variable by name
730+
assert new_y.eval({"neg_x": 100, y_value: 1}) == 101
731+
assert np.abs(x.eval()) < 1
732+
# Confirm the original `y` variable is not changed in place
733+
assert np.abs(y.eval()) < 1

0 commit comments

Comments
 (0)