Commit ffd21e7 (1 parent 1d928af)

Implement observe and do model transformations

6 files changed: +349 additions, -0 deletions
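
The two transforms added here let an existing PyMC model be conditioned on data after the fact (`observe`) or intervened on in the do-operator sense (`do`). A minimal sketch combining both, adapted from the docstrings and tests in this commit (the toy model and values are illustrative only):

    import pymc as pm

    from pymc_experimental.model_transform.conditioning import do, observe

    with pm.Model() as m:
        x = pm.Normal("x")
        y = pm.Normal("y", x)
        z = pm.Normal("z", y)

    # Condition: turn the free RV `y` into an observed RV with value 0.5
    m_observed = observe(m, {y: 0.5})

    # Intervene: replace `y` by a constant, severing its dependence on `x`
    m_do = do(m, {y: 100.0})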

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
@@ -35,6 +35,17 @@ Distributions
     histogram_approximation
 
 
+Model Transformations
+=====================
+
+.. currentmodule:: pymc_experimental.model_transform
+.. autosummary::
+   :toctree: generated/
+
+   conditioning.do
+   conditioning.observe
+
+
 Utils
 =====
 

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
pymc_experimental/model_transform/conditioning.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
from typing import Any, Dict, List, Sequence, Union

from pymc import Model
from pymc.pytensorf import _replace_vars_in_graphs
from pytensor.tensor import TensorVariable

from pymc_experimental.utils.model_fgraph import (
    ModelDeterministic,
    ModelFreeRV,
    extract_dims,
    fgraph_from_model,
    model_from_fgraph,
    model_named,
    model_observed_rv,
    toposort_replace,
)


def observe(model: Model, vars_to_observations: Dict[Union[str, TensorVariable], Any]) -> Model:
    """Convert free RVs or Deterministics to observed RVs.

    Parameters
    ----------
    model: PyMC Model
    vars_to_observations: Dict of variable or name to TensorLike
        Dictionary that maps model variables (or names) to observed values.
        Observed values must have a shape and data type that is compatible
        with the original model variable.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model with the relevant variables observed.
        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.

    Examples
    --------

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.model_transform.conditioning import observe

        with pm.Model() as m:
            x = pm.Normal("x")
            y = pm.Normal("y", x)
            z = pm.Normal("z", y)

        m_new = observe(m, {y: 0.5})

    Deterministic variables can also be observed.
    This relies on PyMC's ability to infer the logp of the underlying expression.

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.model_transform.conditioning import observe

        with pm.Model() as m:
            x = pm.Normal("x")
            y = pm.Normal.dist(x, shape=(5,))
            y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))

        new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})

    """
    vars_to_observations = {
        model[var] if isinstance(var, str) else var: obs
        for var, obs in vars_to_observations.items()
    }

    # Note: Since PyMC can infer logprob expressions we also allow observing Deterministics
    valid_model_vars = set(model.free_RVs + model.deterministics)
    if any(var not in valid_model_vars for var in vars_to_observations):
        raise ValueError("At least one var is not a free variable or deterministic in the model")

    fgraph, memo = fgraph_from_model(model)

    replacements = {}
    for var, obs in vars_to_observations.items():
        model_var = memo[var]

        # Just a sanity check
        assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
        assert model_var in fgraph.variables

        var = model_var.owner.inputs[0]
        var.name = model_var.name
        dims = extract_dims(var)
        model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims)
        replacements[model_var] = model_obs_rv

    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)


def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
    def replacement_fn(var, inner_replacements):
        if var in replacements:
            inner_replacements[var] = replacements[var]

        # Handle root inputs as those will never be passed to the replacement_fn
        for inp in var.owner.inputs:
            if inp.owner is None and inp in replacements:
                inner_replacements[inp] = replacements[inp]

        return [var]

    replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
    return replaced_graphs


def do(model: Model, vars_to_interventions: Dict[Union[str, TensorVariable], Any]) -> Model:
    """Replace model variables by intervention expressions.

    Parameters
    ----------
    model: PyMC Model
    vars_to_interventions: Dict of variable or name to TensorLike
        Dictionary that maps model variables (or names) to intervention expressions.
        Intervention expressions must have a shape and data type that is compatible
        with the original model variable.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model with the relevant variables replaced by the intervention expressions.
        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.

    Examples
    --------

    .. code-block:: python

        import arviz as az
        import pymc as pm
        from pymc_experimental.model_transform.conditioning import do

        with pm.Model() as m:
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", x, 1)
            z = pm.Normal("z", y + x, 1)

        # Dummy posterior, same as calling `pm.sample`
        idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})

        # Replace `y` by a constant `100.0`
        m_do = do(m, {y: 100.0})
        with m_do:
            idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")

    """
    do_mapping = {}
    for var, obs in vars_to_interventions.items():
        if isinstance(var, str):
            var = model[var]
        do_mapping[var] = var.type.filter_variable(obs)

    if any(var not in (model.basic_RVs + model.deterministics) for var in do_mapping):
        raise ValueError("At least one var is not a variable or deterministic in the model")

    fgraph, memo = fgraph_from_model(model)

    # We need the interventions defined in terms of the IR fgraph representation,
    # in case they reference other variables in the model
    ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)

    replacements = {}
    for var, intervention in zip(do_mapping, ir_interventions):
        model_var = memo[var]

        # Just a sanity check
        assert model_var in fgraph.variables

        intervention.name = model_var.name
        dims = extract_dims(model_var)
        new_var = model_named(intervention, *dims)

        replacements[model_var] = new_var

    # Replace variables by interventions
    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)
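
Worth noting about `do`: because the intervention expressions are remapped into the IR fgraph (via `replace_vars_in_graphs` above), an intervention can reference other variables of the original model rather than just constants. A minimal sketch of that use, mirroring `test_do` in the test file below (the tight 1e-3 sigmas only make the drawn values easy to check):

    import pymc as pm

    from pymc_experimental.model_transform.conditioning import do

    with pm.Model() as m:
        x = pm.Normal("x", 0, 1e-3)
        y = pm.Normal("y", x, 1e-3)
        z = pm.Normal("z", y + x, 1e-3)

    # Intervene on `y` with an expression that still depends on `x`
    m_new = do(m, {y: x + 100})

    # `y` becomes a named (non-random) variable; `z` now tracks the intervention
    assert 95 < pm.draw(m_new["z"]) < 105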

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.
pymc_experimental/tests/model_transform/test_conditioning.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import arviz as az
import numpy as np
import pymc as pm
from pymc.variational.minibatch_rv import create_minibatch_rv
from pytensor import config

from pymc_experimental.model_transform.conditioning import do, observe


def test_observe():
    with pm.Model() as m_old:
        x = pm.Normal("x")
        y = pm.Normal("y", x)
        z = pm.Normal("z", y)

    m_new = observe(m_old, {y: 0.5})

    assert len(m_new.free_RVs) == 2
    assert len(m_new.observed_RVs) == 1
    assert m_new["x"] in m_new.free_RVs
    assert m_new["y"] in m_new.observed_RVs
    assert m_new["z"] in m_new.free_RVs

    np.testing.assert_allclose(
        m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
        m_new.compile_logp()({"x": 0.9, "z": 1.4}),
    )

    # Test two substitutions
    m_new = observe(m_old, {y: 0.5, z: 1.4})

    assert len(m_new.free_RVs) == 1
    assert len(m_new.observed_RVs) == 2
    assert m_new["x"] in m_new.free_RVs
    assert m_new["y"] in m_new.observed_RVs
    assert m_new["z"] in m_new.observed_RVs

    np.testing.assert_allclose(
        m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
        m_new.compile_logp()({"x": 0.9}),
    )


def test_observe_minibatch():
    data = np.zeros((100,), dtype=config.floatX)
    batch_size = 10
    with pm.Model() as m_old:
        x = pm.Normal("x")
        y = pm.Normal("y", x)
        # Minibatch RVs are usually created with `total_size` kwarg
        z_raw = pm.Normal.dist(y, shape=batch_size)
        mb_z = create_minibatch_rv(z_raw, total_size=data.shape)
        m_old.register_rv(mb_z, name="mb_z")

    mb_data = pm.Minibatch(data, batch_size=batch_size)
    m_new = observe(m_old, {mb_z: mb_data})

    assert len(m_new.free_RVs) == 2
    assert len(m_new.observed_RVs) == 1
    assert m_new["x"] in m_new.free_RVs
    assert m_new["y"] in m_new.free_RVs
    assert m_new["mb_z"] in m_new.observed_RVs

    np.testing.assert_allclose(
        m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}),
        m_new.compile_logp()({"x": 0.9, "y": 0.5}),
    )


def test_observe_deterministic():
    y_censored_obs = np.array([0.9, 0.5, 0.3, 1, 1], dtype=config.floatX)

    with pm.Model() as m_old:
        x = pm.Normal("x")
        y = pm.Normal.dist(x, shape=(5,))
        y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))

    m_new = observe(m_old, {y_censored: y_censored_obs})

    with pm.Model() as m_ref:
        x = pm.Normal("x")
        pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)

    np.testing.assert_allclose(
        m_new.compile_logp()({"x": 0.9}),
        m_ref.compile_logp()({"x": 0.9}),
    )


def test_do():
    with pm.Model() as m_old:
        x = pm.Normal("x", 0, 1e-3)
        y = pm.Normal("y", x, 1e-3)
        z = pm.Normal("z", y + x, 1e-3)

    assert -5 < pm.draw(z) < 5

    m_new = do(m_old, {y: x + 100})

    assert len(m_new.free_RVs) == 2
    assert m_new["x"] in m_new.free_RVs
    assert m_new["y"] in m_new.named_vars.values()
    assert m_new["z"] in m_new.free_RVs

    assert 95 < pm.draw(m_new["z"]) < 105

    # Test two substitutions
    with m_old:
        switch = pm.MutableData("switch", 1)
    m_new = do(m_old, {y: 100 * switch, x: 100 * switch})

    assert len(m_new.free_RVs) == 1
    assert m_new["x"] in m_new.named_vars.values()
    assert m_new["y"] in m_new.named_vars.values()
    assert m_new["z"] in m_new.free_RVs

    assert 195 < pm.draw(m_new["z"]) < 205
    with m_new:
        pm.set_data({"switch": 0})
    assert -5 < pm.draw(m_new["z"]) < 5


def test_do_posterior_predictive():
    with pm.Model() as m:
        x = pm.Normal("x", 0, 1)
        y = pm.Normal("y", x, 1)
        z = pm.Normal("z", y + x, 1e-3)

    # Dummy posterior
    idata_m = az.from_dict(
        {
            "x": np.full((2, 500), 25),
            "y": np.full((2, 500), np.nan),
            "z": np.full((2, 500), np.nan),
        }
    )

    # Replace `y` by a constant `100.0`
    m_do = do(m, {y: 100.0})
    with m_do:
        idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")

    assert 120 < idata_do.posterior_predictive["z"].mean() < 130

pymc_experimental/utils/model_fgraph.py

Lines changed: 10 additions & 0 deletions
@@ -326,3 +326,13 @@ def clone_model(model: Model) -> Tuple[Model]:
 
     """
     return model_from_fgraph(fgraph_from_model(model)[0])
+
+
+def extract_dims(var) -> Tuple:
+    dims = ()
+    if isinstance(var, ModelVar):
+        if isinstance(var, ModelValuedVar):
+            dims = var.inputs[2:]
+        else:
+            dims = var.inputs[1:]
+    return dims
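
The conditioning transforms in this commit rely on this helper: `observe` calls `extract_dims` on the variable being converted and passes the result to `model_observed_rv`, and `do` does the same before wrapping the intervention with `model_named`, the intent being that dims attached to the original model variable carry over into the new model.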
