Skip to content

Commit 98e13c9

Browse files
committed
Allow do interventions to reference intervened variable
1 parent 9999e0f commit 98e13c9

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

pymc_experimental/model_transform/conditioning.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import warnings
12
from typing import Any, List, Mapping, Optional, Sequence, Union
23

34
from pymc import Model
45
from pymc.logprob.transforms import RVTransform
56
from pymc.pytensorf import _replace_vars_in_graphs
67
from pymc.util import get_transformed_name, get_untransformed_name
8+
from pytensor.graph import ancestors
79
from pytensor.tensor import TensorVariable
810

911
from pymc_experimental.model_transform.basic import (
@@ -199,7 +201,15 @@ def do(
199201
# Just a sanity check
200202
assert model_var in fgraph.variables
201203

202-
intervention.name = model_var.name
204+
# If the intervention references the original variable we must give it a different name
205+
if model_var in ancestors([intervention]):
206+
intervention.name = f"do_{model_var.name}"
207+
warnings.warn(
208+
f"Intervention expression references the variable that is being intervened: {model_var.name}. "
209+
f"Intervention will be given the name: {intervention.name}"
210+
)
211+
else:
212+
intervention.name = model_var.name
203213
dims = extract_dims(model_var)
204214
# If there are any RVs in the graph we introduce the intervention as a deterministic
205215
if rvs_in_graph([intervention]):

pymc_experimental/tests/model_transform/test_conditioning.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,23 @@ def test_do_prune(prune):
222222
assert set(do_m.named_vars) == orig_named_vars
223223

224224

225+
def test_do_self_reference():
226+
"""Check we can replace a variable by an expression that refers to the same variable."""
227+
with pm.Model() as m:
228+
x = pm.Normal("x", 0, 1)
229+
230+
with pytest.warns(
231+
UserWarning,
232+
match="Intervention expression references the variable that is being intervened",
233+
):
234+
new_m = do(m, {x: x + 100})
235+
236+
x = new_m["x"]
237+
do_x = new_m["do_x"]
238+
draw_x, draw_do_x = pm.draw([x, do_x], draws=5)
239+
np.testing.assert_allclose(draw_x + 100, draw_do_x)
240+
241+
225242
def test_change_value_transforms():
226243
with pm.Model() as base_m:
227244
p = pm.Uniform("p", 0, 1, transform=None)

0 commit comments

Comments
 (0)