Skip to content

Commit 4ef10c9

Browse files
committed
Implement observe and do model transformations
1 parent 08a60ae commit 4ef10c9

File tree

5 files changed

+184
-0
lines changed

5 files changed

+184
-0
lines changed

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import Any, Dict, List, Sequence, Union
2+
3+
from pymc import Model
4+
from pymc.pytensorf import _replace_rvs_in_graphs
5+
from pytensor.tensor import TensorVariable
6+
7+
from pymc_experimental.utils.model_fgraph import (
8+
ModelFreeRV,
9+
extract_dims,
10+
fgraph_from_model,
11+
model_from_fgraph,
12+
model_named,
13+
model_observed_rv,
14+
toposort_replace,
15+
)
16+
17+
18+
def observe(model: Model, vars_to_observations: Dict[Union["str"], Any]) -> Model:
19+
"""Convert free RVS to observed RVs.
20+
21+
This function returns a new model.
22+
"""
23+
vars_to_observations = {
24+
model[var] if isinstance(var, str) else var: obs
25+
for var, obs in vars_to_observations.items()
26+
}
27+
28+
# Note: Since PyMC can infer logprob expressions we could also allow observing Deterministics
29+
if any(var not in model.free_RVs for var in vars_to_observations):
30+
raise ValueError(f"At least one var is not a free variable in the model")
31+
32+
fgraph, memo = fgraph_from_model(model)
33+
34+
replacements = {}
35+
for var, obs in vars_to_observations.items():
36+
model_free_rv = memo[var]
37+
38+
# Just a sanity check
39+
assert isinstance(model_free_rv.owner.op, ModelFreeRV)
40+
assert model_free_rv in fgraph.variables
41+
42+
rv, vv, *dims = model_free_rv.owner.inputs
43+
model_obs_rv = model_observed_rv(rv, rv.type.filter_variable(obs), *dims)
44+
replacements[model_free_rv] = model_obs_rv
45+
46+
toposort_replace(fgraph, tuple(replacements.items()))
47+
48+
return model_from_fgraph(fgraph)
49+
50+
51+
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
52+
def replacement_fn(var, inner_replacements):
53+
if var in replacements:
54+
inner_replacements[var] = replacements[var]
55+
56+
# Handle root inputs as those will never be passed to the replacement_fn
57+
for inp in var.owner.inputs:
58+
if inp.owner is None and inp in replacements:
59+
inner_replacements[inp] = replacements[inp]
60+
61+
return [var]
62+
63+
replaced_graphs, _ = _replace_rvs_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
64+
return replaced_graphs
65+
66+
67+
def do(model: Model, vars_to_interventions: Dict[Union["str"], Any]) -> Model:
68+
"""Replace model variables by intervention variables.
69+
70+
This function returns a new model
71+
"""
72+
do_mapping = {}
73+
for var, obs in vars_to_interventions.items():
74+
if isinstance(var, str):
75+
var = model[var]
76+
do_mapping[var] = var.type.filter_variable(obs)
77+
78+
if any(var not in (model.basic_RVs + model.deterministics) for var in do_mapping):
79+
raise ValueError(f"At least one var is not a variable or deterministic in the model")
80+
81+
fgraph, memo = fgraph_from_model(model)
82+
83+
# We need the interventions defined in terms of the IR fgraph representation,
84+
# In case they reference other variables in the model
85+
ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)
86+
87+
replacements = {}
88+
for var, intervention in zip(do_mapping, ir_interventions):
89+
model_var = memo[var]
90+
91+
# Just a sanity check
92+
assert model_var in fgraph.variables
93+
94+
intervention.name = model_var.name
95+
dims = extract_dims(model_var)
96+
new_var = model_named(intervention, *dims)
97+
98+
replacements[model_var] = new_var
99+
100+
# Replace variables by interventions
101+
toposort_replace(fgraph, tuple(replacements.items()))
102+
103+
return model_from_fgraph(fgraph)

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import pymc as pm
3+
4+
from pymc_experimental.model_transform.conditioning import do, observe
5+
6+
7+
def test_observe():
8+
with pm.Model() as m_old:
9+
x = pm.Normal("x")
10+
y = pm.Normal("y", x)
11+
z = pm.Normal("z", y)
12+
13+
m_new = observe(m_old, {y: 0.5})
14+
15+
assert len(m_new.free_RVs) == 2
16+
assert len(m_new.observed_RVs) == 1
17+
assert m_new["x"] in m_new.free_RVs
18+
assert m_new["y"] in m_new.observed_RVs
19+
assert m_new["z"] in m_new.free_RVs
20+
21+
np.testing.assert_allclose(
22+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
23+
m_new.compile_logp()({"x": 0.9, "z": 1.4}),
24+
)
25+
26+
# Test two substitutions
27+
m_new = observe(m_old, {y: 0.5, z: 1.4})
28+
29+
assert len(m_new.free_RVs) == 1
30+
assert len(m_new.observed_RVs) == 2
31+
assert m_new["x"] in m_new.free_RVs
32+
assert m_new["y"] in m_new.observed_RVs
33+
assert m_new["z"] in m_new.observed_RVs
34+
35+
np.testing.assert_allclose(
36+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
37+
m_new.compile_logp()({"x": 0.9}),
38+
)
39+
40+
41+
def test_do():
42+
with pm.Model() as m_old:
43+
x = pm.Normal("x", 0, 1e-3)
44+
y = pm.Normal("y", x, 1e-3)
45+
z = pm.Normal("z", y + x, 1e-3)
46+
47+
assert -5 < pm.draw(z) < 5
48+
49+
m_new = do(m_old, {y: x + 100})
50+
51+
assert len(m_new.free_RVs) == 2
52+
assert m_new["x"] in m_new.free_RVs
53+
assert m_new["y"] in m_new.named_vars.values()
54+
assert m_new["z"] in m_new.free_RVs
55+
56+
assert 95 < pm.draw(m_new["z"]) < 105
57+
58+
# Test two substitutions
59+
with m_old:
60+
switch = pm.MutableData("switch", 1)
61+
m_new = do(m_old, {y: 100 * switch, x: 100 * switch})
62+
63+
assert len(m_new.free_RVs) == 1
64+
assert m_new["x"] in m_new.named_vars.values()
65+
assert m_new["y"] in m_new.named_vars.values()
66+
assert m_new["z"] in m_new.free_RVs
67+
68+
assert 195 < pm.draw(m_new["z"]) < 205
69+
with m_new:
70+
pm.set_data({"switch": 0})
71+
assert -5 < pm.draw(m_new["z"]) < 5

pymc_experimental/utils/model_fgraph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,13 @@ def clone_model(model: Model) -> Tuple[Model]:
326326
327327
"""
328328
return model_from_fgraph(fgraph_from_model(model)[0])
329+
330+
331+
def extract_dims(var) -> Tuple:
332+
dims = ()
333+
if isinstance(var, ModelVar):
334+
if isinstance(var, ModelValuedVar):
335+
dims = var.inputs[2:]
336+
else:
337+
dims = var.inputs[1:]
338+
return dims

0 commit comments

Comments
 (0)