diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py index 49de9f8e5..c0b5a7b74 100644 --- a/pymc_experimental/model_transform/conditioning.py +++ b/pymc_experimental/model_transform/conditioning.py @@ -1,9 +1,11 @@ +import warnings from typing import Any, List, Mapping, Optional, Sequence, Union from pymc import Model from pymc.logprob.transforms import RVTransform from pymc.pytensorf import _replace_vars_in_graphs from pymc.util import get_transformed_name, get_untransformed_name +from pytensor.graph import ancestors from pytensor.tensor import TensorVariable from pymc_experimental.model_transform.basic import ( @@ -199,7 +201,15 @@ def do( # Just a sanity check assert model_var in fgraph.variables - intervention.name = model_var.name + # If the intervention references the original variable we must give it a different name + if model_var in ancestors([intervention]): + intervention.name = f"do_{model_var.name}" + warnings.warn( + f"Intervention expression references the variable that is being intervened: {model_var.name}. " + f"Intervention will be given the name: {intervention.name}" + ) + else: + intervention.name = model_var.name dims = extract_dims(model_var) # If there are any RVs in the graph we introduce the intervention as a deterministic if rvs_in_graph([intervention]): diff --git a/pymc_experimental/tests/model_transform/test_conditioning.py b/pymc_experimental/tests/model_transform/test_conditioning.py index 6e3261774..6fcc82405 100644 --- a/pymc_experimental/tests/model_transform/test_conditioning.py +++ b/pymc_experimental/tests/model_transform/test_conditioning.py @@ -222,6 +222,23 @@ def test_do_prune(prune): assert set(do_m.named_vars) == orig_named_vars +def test_do_self_reference(): + """Check we can replace a variable by an expression that refers to the same variable.""" + with pm.Model() as m: + x = pm.Normal("x", 0, 1) + + with pytest.warns( + UserWarning, + match="Intervention expression references the variable that is being intervened", + ): + new_m = do(m, {x: x + 100}) + + x = new_m["x"] + do_x = new_m["do_x"] + draw_x, draw_do_x = pm.draw([x, do_x], draws=5) + np.testing.assert_allclose(draw_x + 100, draw_do_x) + + def test_change_value_transforms(): with pm.Model() as base_m: p = pm.Uniform("p", 0, 1, transform=None)