Skip to content

Allow do interventions to reference intervened variable #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import warnings
from typing import Any, List, Mapping, Optional, Sequence, Union

from pymc import Model
from pymc.logprob.transforms import RVTransform
from pymc.pytensorf import _replace_vars_in_graphs
from pymc.util import get_transformed_name, get_untransformed_name
from pytensor.graph import ancestors
from pytensor.tensor import TensorVariable

from pymc_experimental.model_transform.basic import (
Expand Down Expand Up @@ -199,7 +201,15 @@ def do(
# Just a sanity check
assert model_var in fgraph.variables

intervention.name = model_var.name
# If the intervention references the original variable we must give it a different name
if model_var in ancestors([intervention]):
intervention.name = f"do_{model_var.name}"
warnings.warn(
f"Intervention expression references the variable that is being intervened: {model_var.name}. "
f"Intervention will be given the name: {intervention.name}"
)
else:
intervention.name = model_var.name
dims = extract_dims(model_var)
# If there are any RVs in the graph we introduce the intervention as a deterministic
if rvs_in_graph([intervention]):
Expand Down
17 changes: 17 additions & 0 deletions pymc_experimental/tests/model_transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ def test_do_prune(prune):
assert set(do_m.named_vars) == orig_named_vars


def test_do_self_reference():
"""Check we can replace a variable by an expression that refers to the same variable."""
with pm.Model() as m:
x = pm.Normal("x", 0, 1)

with pytest.warns(
UserWarning,
match="Intervention expression references the variable that is being intervened",
):
new_m = do(m, {x: x + 100})

x = new_m["x"]
do_x = new_m["do_x"]
draw_x, draw_do_x = pm.draw([x, do_x], draws=5)
np.testing.assert_allclose(draw_x + 100, draw_do_x)


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
Expand Down