Skip to content

Commit 9a8aa7b

Browse files
committed
Allow do interventions to reference intervened variable
1 parent 430c3c8 commit 9a8aa7b

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

pymc_experimental/model_transform/conditioning.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import warnings
12
from typing import Any, Dict, List, Sequence, Union
23

34
from pymc import Model
45
from pymc.pytensorf import _replace_vars_in_graphs
6+
from pytensor.graph import ancestors
57
from pytensor.tensor import TensorVariable
68

79
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
@@ -188,7 +190,15 @@ def do(
188190
# Just a sanity check
189191
assert model_var in fgraph.variables
190192

191-
intervention.name = model_var.name
193+
# If the intervention references the original variable we must give it a different name
194+
if model_var in ancestors([intervention]):
195+
intervention.name = f"do_{model_var.name}"
196+
warnings.warn(
197+
f"Intervention expression references the variable that is being intervened: {model_var.name}. "
198+
f"Intervention will be given the name: {intervention.name}"
199+
)
200+
else:
201+
intervention.name = model_var.name
192202
dims = extract_dims(model_var)
193203
# If there are any RVs in the graph we introduce the intervention as a deterministic
194204
if rvs_in_graph([intervention]):

pymc_experimental/tests/model_transform/test_conditioning.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,16 @@ def test_do_prune(prune):
214214
assert set(do_m.named_vars) == {"x1", "z", "llike"}
215215
else:
216216
assert set(do_m.named_vars) == orig_named_vars
217+
218+
219+
def test_do_self_reference():
220+
"""Check we can replace a variable by an expression that refers to the same variable."""
221+
with pm.Model() as m:
222+
x = pm.Normal("x", 0, 1)
223+
224+
with pytest.warns(
225+
UserWarning,
226+
match="Intervention expression references the variable that is being intervened",
227+
):
228+
new_x = do(m, {x: x + 100})["do_x"]
229+
assert pm.draw(new_x) > 50

0 commit comments

Comments
 (0)