Skip to content

Commit e507c96

Browse files
committed
Implement observe and do model transformations
1 parent 1abaca8 commit e507c96

File tree

7 files changed

+442
-0
lines changed

7 files changed

+442
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ Distributions
3535
histogram_approximation
3636

3737

38+
Model Transformations
39+
=====================
40+
41+
.. currentmodule:: pymc_experimental.model_transform
42+
.. autosummary::
43+
:toctree: generated/
44+
45+
conditioning.do
46+
conditioning.observe
47+
48+
3849
Utils
3950
=====
4051

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
from typing import Any, Dict, List, Sequence, Union
2+
3+
from pymc import Model
4+
from pymc.pytensorf import _replace_vars_in_graphs
5+
from pytensor.tensor import TensorVariable
6+
7+
from pymc_experimental.utils.model_fgraph import (
8+
ModelDeterministic,
9+
ModelFreeRV,
10+
extract_dims,
11+
fgraph_from_model,
12+
model_deterministic,
13+
model_from_fgraph,
14+
model_named,
15+
model_observed_rv,
16+
toposort_replace,
17+
)
18+
from pymc_experimental.utils.pytensorf import rvs_in_graph
19+
20+
21+
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
    """Turn free RVs or Deterministics of a model into observed RVs.

    Parameters
    ----------
    model: PyMC Model
        Model containing the variables that should become observed.
    vars_to_observations: Dict of variable or name to TensorLike
        Mapping from model variables (or their names) to the values they are
        observed at. Each value must be compatible, in shape and data type,
        with the model variable it is attached to.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model in which the requested variables are observed.
        The other variables are cloned and remain accessible through
        `new_model["var_name"]`.

    Raises
    ------
    ValueError
        If any key is neither a free RV nor a Deterministic of `model`.

    Examples
    --------

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.model_transform.conditioning import observe

        with pm.Model() as m:
            x = pm.Normal("x")
            y = pm.Normal("y", x)
            z = pm.Normal("z", y)

        m_new = observe(m, {y: 0.5})

    Deterministic variables can be observed as well.
    This relies on PyMC's ability to infer the logp of the underlying expression.

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.model_transform.conditioning import observe

        with pm.Model() as m:
            x = pm.Normal("x")
            y = pm.Normal.dist(x, shape=(5,))
            y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))

        new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})

    """
    # Normalise the keys: names are looked up in the model, variables are used as given
    resolved = {}
    for key, value in vars_to_observations.items():
        target = model[key] if isinstance(key, str) else key
        resolved[target] = value

    observable = set(model.free_RVs + model.deterministics)
    if not all(target in observable for target in resolved):
        raise ValueError("At least one var is not a free variable or deterministic in the model")

    fgraph, memo = fgraph_from_model(model)

    replacements = {}
    for target, value in resolved.items():
        ir_var = memo[target]

        # Sanity checks on the fgraph counterpart of the model variable
        assert isinstance(ir_var.owner.op, (ModelFreeRV, ModelDeterministic))
        assert ir_var in fgraph.variables

        # The first input of the wrapping Op is the variable that gets observed;
        # it takes over the name of the model variable it stands for
        inner = ir_var.owner.inputs[0]
        inner.name = ir_var.name
        dims = extract_dims(ir_var)
        observed_value = inner.type.filter_variable(value)
        replacements[ir_var] = model_observed_rv(inner, observed_value, *dims)

    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)
98+
99+
100+
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
    """Apply a fixed mapping of variable replacements to `graphs`.

    Parameters
    ----------
    graphs: Sequence of TensorVariable
        Graphs handed to `_replace_vars_in_graphs`.
    replacements
        Mapping from variables to the variables that should take their place.

    Returns
    -------
    replaced_graphs: List of TensorVariable
        The graphs returned by `_replace_vars_in_graphs`.
    """

    def replacement_fn(var, inner_replacements):
        # Register the visited variable itself when it has a replacement
        if var in replacements:
            inner_replacements[var] = replacements[var]

        # Handle root inputs as those will never be passed to the replacement_fn
        for root in (inp for inp in var.owner.inputs if inp.owner is None):
            if root in replacements:
                inner_replacements[root] = replacements[root]

        return [var]

    new_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
    return new_graphs
114+
115+
116+
def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
    """Replace model variables by intervention variables.

    Intervention variables will either show up as `Data` or `Deterministics` in the new model,
    depending on whether they depend on other RandomVariables or not.

    Parameters
    ----------
    model: PyMC Model
    vars_to_interventions: Dict of variable or name to TensorLike
        Dictionary that maps model variables (or names) to intervention expressions.
        Intervention expressions must have a shape and data type that is compatible
        with the original model variable.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model with the relevant variables replaced by the intervention expressions.
        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.

    Raises
    ------
    TypeError
        If an intervention cannot be converted to the type of the variable it replaces.
    ValueError
        If any of the variables is not a named variable of `model`.

    Examples
    --------

    .. code-block:: python

        import arviz as az
        import pymc as pm
        from pymc_experimental.model_transform.conditioning import do

        with pm.Model() as m:
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", x, 1)
            z = pm.Normal("z", y + x, 1)

        # Dummy posterior, same as calling `pm.sample`
        idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})

        # Replace `y` by a constant `100.0`
        m_do = do(m, {y: 100.0})
        with m_do:
            idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")

    """
    do_mapping = {}
    for var, intervention in vars_to_interventions.items():
        if isinstance(var, str):
            var = model[var]
        try:
            do_mapping[var] = var.type.filter_variable(intervention)
        except TypeError as err:
            raise TypeError(
                "Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"
            ) from err

    # Build the lookup set once instead of scanning the values view for every variable
    named_vars = set(model.named_vars.values())
    if any(var not in named_vars for var in do_mapping):
        raise ValueError("At least one var is not a named variable in the model")

    fgraph, memo = fgraph_from_model(model, inlined_views=True)

    # We need the interventions defined in terms of the IR fgraph representation,
    # In case they reference other variables in the model
    ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)

    replacements = {}
    for var, intervention in zip(do_mapping, ir_interventions):
        model_var = memo[var]

        # Just a sanity check
        assert model_var in fgraph.variables

        # The intervention takes over the name of the variable it replaces
        intervention.name = model_var.name
        dims = extract_dims(model_var)
        # If there are any RVs in the graph we introduce the intervention as a deterministic
        if rvs_in_graph([intervention]):
            new_var = model_deterministic(intervention.copy(name=intervention.name), *dims)
        # Otherwise as a named variable (Constant or Shared data)
        else:
            new_var = model_named(intervention, *dims)

        replacements[model_var] = new_var

    # Replace variables by interventions
    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)