-
-
Notifications
You must be signed in to change notification settings - Fork 61
Implement observe
and do
model transformations
#168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
1262a44
Make fgraph Deterministic conversion logic more robust
ricardoV94 5304406
Allow inlining of Deterministics and Data in fgraph IR
ricardoV94 52706e8
Implement observe and do model transformations
ricardoV94 4deb91a
Add option to prune variables after do intervention
ricardoV94 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from pymc import Model | ||
from pytensor.graph import ancestors | ||
|
||
from pymc_experimental.utils.model_fgraph import ( | ||
ModelObservedRV, | ||
ModelVar, | ||
fgraph_from_model, | ||
model_from_fgraph, | ||
) | ||
|
||
|
||
def prune_vars_detached_from_observed(model: Model) -> Model: | ||
"""Prune model variables that are not related to any observed variable in the Model.""" | ||
|
||
# Potentials are ambiguous as whether they correspond to likelihood or prior terms, | ||
# We simply raise for now | ||
if model.potentials: | ||
raise NotImplementedError("Pruning not implemented for models with Potentials") | ||
|
||
fgraph, _ = fgraph_from_model(model, inlined_views=True) | ||
observed_vars = ( | ||
out | ||
for node in fgraph.apply_nodes | ||
if isinstance(node.op, ModelObservedRV) | ||
for out in node.outputs | ||
) | ||
ancestor_nodes = {var.owner for var in ancestors(observed_vars)} | ||
nodes_to_remove = { | ||
node | ||
for node in fgraph.apply_nodes | ||
if isinstance(node.op, ModelVar) and node not in ancestor_nodes | ||
} | ||
for node_to_remove in nodes_to_remove: | ||
fgraph.remove_node(node_to_remove) | ||
return model_from_fgraph(fgraph) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
from typing import Any, Dict, List, Sequence, Union | ||
|
||
from pymc import Model | ||
from pymc.pytensorf import _replace_vars_in_graphs | ||
from pytensor.tensor import TensorVariable | ||
|
||
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed | ||
from pymc_experimental.utils.model_fgraph import ( | ||
ModelDeterministic, | ||
ModelFreeRV, | ||
extract_dims, | ||
fgraph_from_model, | ||
model_deterministic, | ||
model_from_fgraph, | ||
model_named, | ||
model_observed_rv, | ||
toposort_replace, | ||
) | ||
from pymc_experimental.utils.pytensorf import rvs_in_graph | ||
|
||
|
||
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model: | ||
"""Convert free RVs or Deterministics to observed RVs. | ||
|
||
Parameters | ||
---------- | ||
model: PyMC Model | ||
vars_to_observations: Dict of variable or name to TensorLike | ||
Dictionary that maps model variables (or names) to observed values. | ||
Observed values must have a shape and data type that is compatible | ||
with the original model variable. | ||
|
||
Returns | ||
------- | ||
new_model: PyMC model | ||
A distinct PyMC model with the relevant variables observed. | ||
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`. | ||
|
||
Examples | ||
-------- | ||
|
||
.. code-block:: python | ||
|
||
import pymc as pm | ||
from pymc_experimental.model_transform.conditioning import observe | ||
|
||
with pm.Model() as m: | ||
x = pm.Normal("x") | ||
y = pm.Normal("y", x) | ||
z = pm.Normal("z", y) | ||
|
||
m_new = observe(m, {y: 0.5}) | ||
|
||
Deterministic variables can also be observed. | ||
This relies on PyMC ability to infer the logp of the underlying expression | ||
|
||
.. code-block:: python | ||
|
||
import pymc as pm | ||
from pymc_experimental.model_transform.conditioning import observe | ||
|
||
with pm.Model() as m: | ||
x = pm.Normal("x") | ||
y = pm.Normal.dist(x, shape=(5,)) | ||
y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1)) | ||
|
||
new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]}) | ||
|
||
|
||
""" | ||
vars_to_observations = { | ||
model[var] if isinstance(var, str) else var: obs | ||
for var, obs in vars_to_observations.items() | ||
} | ||
|
||
valid_model_vars = set(model.free_RVs + model.deterministics) | ||
if any(var not in valid_model_vars for var in vars_to_observations): | ||
raise ValueError(f"At least one var is not a free variable or deterministic in the model") | ||
|
||
fgraph, memo = fgraph_from_model(model) | ||
|
||
replacements = {} | ||
for var, obs in vars_to_observations.items(): | ||
model_var = memo[var] | ||
|
||
# Just a sanity check | ||
assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic)) | ||
assert model_var in fgraph.variables | ||
|
||
var = model_var.owner.inputs[0] | ||
var.name = model_var.name | ||
dims = extract_dims(model_var) | ||
model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims) | ||
replacements[model_var] = model_obs_rv | ||
|
||
toposort_replace(fgraph, tuple(replacements.items())) | ||
|
||
return model_from_fgraph(fgraph) | ||
|
||
|
||
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]: | ||
def replacement_fn(var, inner_replacements): | ||
if var in replacements: | ||
inner_replacements[var] = replacements[var] | ||
|
||
# Handle root inputs as those will never be passed to the replacement_fn | ||
for inp in var.owner.inputs: | ||
if inp.owner is None and inp in replacements: | ||
inner_replacements[inp] = replacements[inp] | ||
|
||
return [var] | ||
|
||
replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn) | ||
return replaced_graphs | ||
|
||
|
||
def do( | ||
model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False | ||
) -> Model: | ||
"""Replace model variables by intervention variables. | ||
|
||
Intervention variables will either show up as `Data` or `Deterministics` in the new model, | ||
depending on whether they depend on other RandomVariables or not. | ||
|
||
Parameters | ||
---------- | ||
model: PyMC Model | ||
vars_to_interventions: Dict of variable or name to TensorLike | ||
Dictionary that maps model variables (or names) to intervention expressions. | ||
Intervention expressions must have a shape and data type that is compatible | ||
with the original model variable. | ||
prune_vars: bool, defaults to False | ||
Whether to prune model variables that are not connected to any observed variables, | ||
after the interventions. | ||
|
||
Returns | ||
------- | ||
new_model: PyMC model | ||
A distinct PyMC model with the relevant variables replaced by the intervention expressions. | ||
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`. | ||
|
||
Examples | ||
-------- | ||
|
||
.. code-block:: python | ||
|
||
import pymc as pm | ||
from pymc_experimental.model_transform.conditioning import do | ||
|
||
with pm.Model() as m: | ||
x = pm.Normal("x", 0, 1) | ||
y = pm.Normal("y", x, 1) | ||
z = pm.Normal("z", y + x, 1) | ||
|
||
# Dummy posterior, same as calling `pm.sample` | ||
idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]}) | ||
|
||
# Replace `y` by a constant `100.0` | ||
m_do = do(m, {y: 100.0}) | ||
with m_do: | ||
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z") | ||
|
||
""" | ||
do_mapping = {} | ||
for var, obs in vars_to_interventions.items(): | ||
if isinstance(var, str): | ||
var = model[var] | ||
try: | ||
do_mapping[var] = var.type.filter_variable(obs) | ||
except TypeError as err: | ||
raise TypeError( | ||
"Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables" | ||
) from err | ||
|
||
if any(var not in model.named_vars.values() for var in do_mapping): | ||
raise ValueError(f"At least one var is not a named variable in the model") | ||
|
||
fgraph, memo = fgraph_from_model(model, inlined_views=True) | ||
|
||
# We need the interventions defined in terms of the IR fgraph representation, | ||
# In case they reference other variables in the model | ||
ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo) | ||
|
||
replacements = {} | ||
for var, intervention in zip(do_mapping, ir_interventions): | ||
model_var = memo[var] | ||
|
||
# Just a sanity check | ||
assert model_var in fgraph.variables | ||
|
||
intervention.name = model_var.name | ||
dims = extract_dims(model_var) | ||
# If there are any RVs in the graph we introduce the intervention as a deterministic | ||
if rvs_in_graph([intervention]): | ||
new_var = model_deterministic(intervention.copy(name=intervention.name), *dims) | ||
# Otherwise as a named variable (Constant or Shared data) | ||
else: | ||
new_var = model_named(intervention, *dims) | ||
|
||
replacements[model_var] = new_var | ||
|
||
# Replace variables by interventions | ||
toposort_replace(fgraph, tuple(replacements.items())) | ||
|
||
model = model_from_fgraph(fgraph) | ||
if prune_vars: | ||
return prune_vars_detached_from_observed(model) | ||
return model |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import pymc as pm | ||
|
||
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed | ||
|
||
|
||
def test_prune_vars_detached_from_observed(): | ||
with pm.Model() as m: | ||
obs_data = pm.MutableData("obs_data", 0) | ||
a0 = pm.ConstantData("a0", 0) | ||
a1 = pm.Normal("a1", a0) | ||
a2 = pm.Normal("a2", a1) | ||
pm.Normal("obs", a2, observed=obs_data) | ||
|
||
d0 = pm.ConstantData("d0", 0) | ||
d1 = pm.Normal("d1", d0) | ||
|
||
assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} | ||
pruned_m = prune_vars_detached_from_observed(m) | ||
assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.