-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add graphviz model graphs #3049
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1,466 changes: 1,180 additions & 286 deletions
1,466
docs/source/notebooks/multilevel_modeling.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from theano.gof.graph import inputs | ||
|
||
from .util import get_default_varnames | ||
import pymc3 as pm | ||
|
||
|
||
class ModelGraph(object): | ||
def __init__(self, model): | ||
self.model = model | ||
self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False) | ||
self.var_list = self.model.named_vars.values() | ||
self.transform_map = {v.transformed: v.name for v in self.var_list if hasattr(v, 'transformed')} | ||
self._deterministics = None | ||
|
||
def get_deterministics(self, var): | ||
"""Compute the deterministic nodes of the graph""" | ||
deterministics = [] | ||
attrs = ('transformed', 'logpt') | ||
for v in self.var_list: | ||
if v != var and all(not hasattr(v, attr) for attr in attrs): | ||
deterministics.append(v) | ||
return deterministics | ||
|
||
def _inputs(self, var, func, blockers=None): | ||
"""Get inputs to a function that are also named PyMC3 variables""" | ||
return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var]) | ||
|
||
def _get_inputs(self, var, func): | ||
"""Get all inputs to a function, doing some accounting for deterministics | ||
|
||
Specifically, if a deterministic is an input, theano.gof.graph.inputs will | ||
return only the inputs *to the deterministic*. However, if we pass in the | ||
deterministic as a blocker, it will skip those nodes. | ||
""" | ||
deterministics = self.get_deterministics(var) | ||
upstream = self._inputs(var, func) | ||
parents = self._inputs(var, func, blockers=deterministics) | ||
if parents != upstream: | ||
det_map = {} | ||
for d in deterministics: | ||
d_set = {j for j in inputs([func], blockers=[d])} | ||
if upstream - d_set: | ||
det_map[d] = d_set | ||
for d, d_set in det_map.items(): | ||
if all(d_set.issubset(other) for other in det_map.values()): | ||
parents.add(d) | ||
return parents | ||
|
||
def _filter_parents(self, var, parents): | ||
"""Get direct parents of a var, as strings""" | ||
keep = set() | ||
for p in parents: | ||
if p == var: | ||
continue | ||
elif p.name in self.var_names: | ||
keep.add(p.name) | ||
elif p in self.transform_map: | ||
if self.transform_map[p] != var.name: | ||
keep.add(self.transform_map[p]) | ||
else: | ||
raise AssertionError('Do not know what to do with {}'.format(str(p))) | ||
return keep | ||
|
||
def get_parents(self, var): | ||
"""Get the named nodes that are direct inputs to the var""" | ||
if hasattr(var, 'transformed'): | ||
func = var.transformed.logpt | ||
elif hasattr(var, 'logpt'): | ||
func = var.logpt | ||
else: | ||
func = var | ||
|
||
parents = self._get_inputs(var, func) | ||
return self._filter_parents(var, parents) | ||
|
||
def make_compute_graph(self): | ||
"""Get map of var_name -> set(input var names) for the model""" | ||
input_map = {} | ||
for var_name in self.var_names: | ||
input_map[var_name] = self.get_parents(self.model[var_name]) | ||
return input_map | ||
|
||
def _make_node(self, var_name, graph): | ||
"""Attaches the given variable to a graphviz Digraph""" | ||
v = self.model[var_name] | ||
|
||
# styling for node | ||
attrs = {} | ||
if isinstance(v, pm.model.ObservedRV): | ||
attrs['style'] = 'filled' | ||
|
||
# Get name for node | ||
if hasattr(v, 'distribution'): | ||
distribution = v.distribution.__class__.__name__ | ||
else: | ||
distribution = 'Deterministic' | ||
attrs['shape'] = 'box' | ||
|
||
graph.node(var_name, | ||
'{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution), | ||
**attrs) | ||
|
||
def get_plates(self): | ||
""" Rough but surprisingly accurate plate detection. | ||
|
||
Just groups by the shape of the underlying distribution. Will be wrong | ||
if there are two plates with the same shape. | ||
|
||
Returns | ||
------- | ||
dict: str -> set[str] | ||
""" | ||
plates = {} | ||
for var_name in self.var_names: | ||
v = self.model[var_name] | ||
if hasattr(v, 'observations'): | ||
shape = v.observations.shape | ||
elif hasattr(v, 'dshape'): | ||
shape = v.dshape | ||
else: | ||
shape = v.tag.test_value.shape | ||
if shape == (1,): | ||
shape = tuple() | ||
if shape not in plates: | ||
plates[shape] = set() | ||
plates[shape].add(var_name) | ||
return plates | ||
|
||
def make_graph(self): | ||
"""Make graphviz Digraph of PyMC3 model | ||
|
||
Returns | ||
------- | ||
graphviz.Digraph | ||
""" | ||
try: | ||
import graphviz | ||
except ImportError: | ||
raise ImportError('This function requires the python library graphviz, along with binaries. ' | ||
'The easiest way to install all of this is by running\n\n' | ||
'\tconda install -c conda-forge python-graphviz') | ||
graph = graphviz.Digraph(self.model.name) | ||
for shape, var_names in self.get_plates().items(): | ||
label = ' x '.join(map('{:,d}'.format, shape)) | ||
if label: | ||
# must be preceded by 'cluster' to get a box around it | ||
with graph.subgraph(name='cluster' + label) as sub: | ||
for var_name in var_names: | ||
self._make_node(var_name, sub) | ||
# plate label goes bottom right | ||
sub.attr(label=label, labeljust='r', labelloc='b', style='rounded') | ||
else: | ||
for var_name in var_names: | ||
self._make_node(var_name, graph) | ||
|
||
for key, values in self.make_compute_graph().items(): | ||
for value in values: | ||
graph.edge(value, key) | ||
return graph | ||
|
||
|
||
def model_to_graphviz(model=None): | ||
"""Produce a graphviz Digraph from a PyMC3 model. | ||
|
||
Requires graphviz, which may be installed most easily with | ||
conda install -c conda-forge python-graphviz | ||
|
||
Alternatively, you may install the `graphviz` binaries yourself, | ||
and then `pip install graphviz` to get the python bindings. See | ||
http://graphviz.readthedocs.io/en/stable/manual.html | ||
for more information. | ||
""" | ||
model = pm.modelcontext(model) | ||
return ModelGraph(model).make_graph() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
import pymc3 as pm | ||
from pymc3.model_graph import ModelGraph, model_to_graphviz | ||
|
||
from .helpers import SeededTest | ||
|
||
|
||
def radon_model(): | ||
"""Similar in shape to the Radon model""" | ||
n_homes = 919 | ||
counties = 85 | ||
uranium = np.random.normal(-.1, 0.4, size=n_homes) | ||
xbar = np.random.normal(1, 0.1, size=n_homes) | ||
floor_measure = np.random.randint(0, 2, size=n_homes) | ||
log_radon = np.random.normal(1, 1, size=n_homes) | ||
|
||
d, r = divmod(919, 85) | ||
county = np.hstack(( | ||
np.tile(np.arange(counties, dtype=int), d), | ||
np.arange(r) | ||
)) | ||
with pm.Model() as model: | ||
sigma_a = pm.HalfCauchy('sigma_a', 5) | ||
gamma = pm.Normal('gamma', mu=0., sd=1e5, shape=3) | ||
mu_a = pm.Deterministic('mu_a', gamma[0] + gamma[1]*uranium + gamma[2]*xbar) | ||
eps_a = pm.Normal('eps_a', mu=0, sd=sigma_a, shape=counties) | ||
a = pm.Deterministic('a', mu_a + eps_a[county]) | ||
b = pm.Normal('b', mu=0., sd=1e15) | ||
sigma_y = pm.Uniform('sigma_y', lower=0, upper=100) | ||
y_hat = a + b * floor_measure | ||
y_like = pm.Normal('y_like', mu=y_hat, sd=sigma_y, observed=log_radon) | ||
|
||
compute_graph = { | ||
'sigma_a': set(), | ||
'gamma': set(), | ||
'mu_a': {'gamma'}, | ||
'eps_a': {'sigma_a'}, | ||
'a': {'mu_a', 'eps_a'}, | ||
'b': set(), | ||
'sigma_y': set(), | ||
'y_like': {'a', 'b', 'sigma_y'} | ||
} | ||
plates = { | ||
(): {'b', 'sigma_a', 'sigma_y'}, | ||
(3,): {'gamma'}, | ||
(85,): {'eps_a'}, | ||
(919,): {'a', 'mu_a', 'y_like'}, | ||
} | ||
return model, compute_graph, plates | ||
|
||
|
||
class TestSimpleModel(SeededTest): | ||
@classmethod | ||
def setup_class(cls): | ||
cls.model, cls.compute_graph, cls.plates = radon_model() | ||
cls.model_graph = ModelGraph(cls.model) | ||
|
||
def test_inputs(self): | ||
for child, parents in self.compute_graph.items(): | ||
var = self.model[child] | ||
found_parents = self.model_graph.get_parents(var) | ||
assert found_parents == parents | ||
|
||
def test_compute_graph(self): | ||
assert self.compute_graph == self.model_graph.make_compute_graph() | ||
|
||
def test_plates(self): | ||
assert self.plates == self.model_graph.get_plates() | ||
|
||
def test_graphviz(self): | ||
# just make sure everything runs without error | ||
|
||
g = self.model_graph.make_graph() | ||
for key in self.compute_graph: | ||
assert key in g.source | ||
g = model_to_graphviz(self.model) | ||
for key in self.compute_graph: | ||
assert key in g.source | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
bokeh>=0.12.13 | ||
CommonMark==0.5.4 | ||
graphviz>=0.8.3 | ||
h5py>=2.7.0 | ||
ipython | ||
Keras>=2.0.8 | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need newline