Skip to content

Add graphviz model graphs #3049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
the unconditioned model.
- SMC: remove experimental warning, allow sampling using `sample`, reduce autocorrelation from
final trace.
- Add `model_to_graphviz` (which uses the optional dependency `graphviz`) to
plot a directed graph of a PyMC3 model using plate notation.

### Fixes

Expand Down
1,466 changes: 1,180 additions & 286 deletions docs/source/notebooks/multilevel_modeling.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import gp
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
from .model import *
from .model_graph import model_to_graphviz
from .stats import *
from .sampling import *
from .step_methods import *
Expand Down
174 changes: 174 additions & 0 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from theano.gof.graph import inputs

from .util import get_default_varnames
import pymc3 as pm


class ModelGraph(object):
    """Derive a graphviz picture of a PyMC3 model.

    For every default (untransformed) named variable of the wrapped model this
    class can work out which other named variables feed into it, group the
    variables into plates by shape, and render everything as a
    ``graphviz.Digraph``.
    """

    def __init__(self, model):
        """Store the model and precompute the lookups used by the other methods.

        Parameters
        ----------
        model : object exposing ``named_vars`` (a mapping of name -> variable),
            item access by variable name, and a ``name`` attribute.
        """
        self.model = model
        self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
        self.var_list = self.model.named_vars.values()
        # transformed variable -> name of the variable that carries it as `.transformed`
        self.transform_map = {v.transformed: v.name for v in self.var_list if hasattr(v, 'transformed')}
        # never read inside this class; kept so the instance attributes stay unchanged
        self._deterministics = None

    def get_deterministics(self, var):
        """Compute the deterministic nodes of the graph

        A variable counts as deterministic here when it has neither a
        ``transformed`` nor a ``logpt`` attribute.

        Parameters
        ----------
        var : variable to leave out of the result

        Returns
        -------
        list of the variables in ``self.var_list`` (other than ``var``) that
        lack both attributes
        """
        attrs = ('transformed', 'logpt')
        return [v for v in self.var_list
                if v != var and not any(hasattr(v, attr) for attr in attrs)]

    def _inputs(self, var, func, blockers=None):
        """Get inputs to a function that are also named PyMC3 variables

        Returns the set of graph inputs of ``func`` (as reported by
        ``theano.gof.graph.inputs`` with the given ``blockers``) that are in
        ``self.var_list`` and are not ``var`` itself.
        """
        return {j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var}

    def _get_inputs(self, var, func):
        """Get all inputs to a function, doing some accounting for deterministics

        Specifically, if a deterministic is an input, theano.gof.graph.inputs will
        return only the inputs *to the deterministic*.  However, if we pass in the
        deterministic as a blocker, it will skip those nodes.

        Returns
        -------
        set of variables considered parents of ``var`` through ``func``
        """
        deterministics = self.get_deterministics(var)
        upstream = self._inputs(var, func)
        parents = self._inputs(var, func, blockers=deterministics)
        if parents != upstream:
            # Blocking on the deterministics changed the result, so at least one
            # of them sits between `func` and its named inputs.
            det_map = {}
            for d in deterministics:
                d_set = set(inputs([func], blockers=[d]))
                # keep d only if blocking on it hides some named upstream input
                if upstream - d_set:
                    det_map[d] = d_set
            for d, d_set in det_map.items():
                # add d only when its input set is contained in every candidate's set
                if all(d_set.issubset(other) for other in det_map.values()):
                    parents.add(d)
        return parents

    def _filter_parents(self, var, parents):
        """Get direct parents of a var, as strings

        Parameters
        ----------
        var : the variable whose parents are being resolved
        parents : iterable of candidate parent variables

        Returns
        -------
        set of parent names; transformed parents are reported under the name
        stored for them in ``self.transform_map``

        Raises
        ------
        AssertionError
            If a candidate is neither a default variable nor a known transform.
        """
        keep = set()
        for p in parents:
            if p == var:
                continue
            elif p.name in self.var_names:
                keep.add(p.name)
            elif p in self.transform_map:
                # a variable's own transformed version is not one of its parents
                if self.transform_map[p] != var.name:
                    keep.add(self.transform_map[p])
            else:
                raise AssertionError('Do not know what to do with {}'.format(str(p)))
        return keep

    def get_parents(self, var):
        """Get the named nodes that are direct inputs to the var

        The expression that is inspected is ``var.transformed.logpt`` when the
        variable has a ``transformed`` attribute, else ``var.logpt`` when it
        has one, else the variable itself.

        Returns
        -------
        set of parent variable names
        """
        if hasattr(var, 'transformed'):
            func = var.transformed.logpt
        elif hasattr(var, 'logpt'):
            func = var.logpt
        else:
            func = var

        parents = self._get_inputs(var, func)
        return self._filter_parents(var, parents)

    def make_compute_graph(self):
        """Get map of var_name -> set(input var names) for the model"""
        return {var_name: self.get_parents(self.model[var_name]) for var_name in self.var_names}

    def _make_node(self, var_name, graph):
        """Attaches the given variable to a graphviz Digraph

        Parameters
        ----------
        var_name : name of the model variable to draw
        graph : graph or subgraph object whose ``node`` method is called
        """
        v = self.model[var_name]

        # styling for node
        attrs = {}
        if isinstance(v, pm.model.ObservedRV):
            attrs['style'] = 'filled'

        # Get name for node
        if hasattr(v, 'distribution'):
            distribution = v.distribution.__class__.__name__
        else:
            # variables without a distribution are drawn as boxes
            distribution = 'Deterministic'
            attrs['shape'] = 'box'

        graph.node(var_name,
                   '{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution),
                   **attrs)

    def get_plates(self):
        """ Rough but surprisingly accurate plate detection.

        Just groups by the shape of the underlying distribution.  Will be wrong
        if there are two plates with the same shape.

        The shape is read from ``observations.shape`` if the variable has
        observations, else from ``dshape`` if present, else from
        ``tag.test_value.shape``.  A shape of ``(1,)`` is treated as ``()``.

        Returns
        -------
        dict: str -> set[str]
        """
        plates = {}
        for var_name in self.var_names:
            v = self.model[var_name]
            if hasattr(v, 'observations'):
                shape = v.observations.shape
            elif hasattr(v, 'dshape'):
                shape = v.dshape
            else:
                shape = v.tag.test_value.shape
            if shape == (1,):
                shape = tuple()
            plates.setdefault(shape, set()).add(var_name)
        return plates

    def make_graph(self):
        """Make graphviz Digraph of PyMC3 model

        Returns
        -------
        graphviz.Digraph

        Raises
        ------
        ImportError
            If the ``graphviz`` python package cannot be imported.
        """
        try:
            import graphviz
        except ImportError:
            raise ImportError('This function requires the python library graphviz, along with binaries. '
                              'The easiest way to install all of this is by running\n\n'
                              '\tconda install -c conda-forge python-graphviz')
        graph = graphviz.Digraph(self.model.name)
        for shape, var_names in self.get_plates().items():
            label = ' x '.join(map('{:,d}'.format, shape))
            if label:
                # must be preceded by 'cluster' to get a box around it
                with graph.subgraph(name='cluster' + label) as sub:
                    for var_name in var_names:
                        self._make_node(var_name, sub)
                    # plate label goes bottom right
                    sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')
            else:
                # empty shape -> empty label: these nodes go on the top-level graph
                for var_name in var_names:
                    self._make_node(var_name, graph)

        # one edge per (parent, child) pair
        for key, values in self.make_compute_graph().items():
            for value in values:
                graph.edge(value, key)
        return graph


def model_to_graphviz(model=None):
    """Return a graphviz Digraph describing a PyMC3 model.

    The model is resolved with ``pm.modelcontext(model)`` and then handed
    to ``ModelGraph``.

    This needs graphviz.  The simplest way to get it is
        conda install -c conda-forge python-graphviz

    You can also install the `graphviz` binaries on your own and then run
    `pip install graphviz` for the python bindings.  More details are at
    http://graphviz.readthedocs.io/en/stable/manual.html
    """
    return ModelGraph(pm.modelcontext(model)).make_graph()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need newline

79 changes: 79 additions & 0 deletions pymc3/tests/test_model_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pymc3 as pm
from pymc3.model_graph import ModelGraph, model_to_graphviz

from .helpers import SeededTest


def radon_model():
    """Similar in shape to the Radon model

    Returns
    -------
    tuple of ``(model, compute_graph, plates)`` where ``compute_graph`` maps
    each variable name to the set of its expected parent names and ``plates``
    maps each expected shape to the set of variable names with that shape.
    """
    n_homes = 919
    counties = 85
    uranium = np.random.normal(-.1, 0.4, size=n_homes)
    xbar = np.random.normal(1, 0.1, size=n_homes)
    floor_measure = np.random.randint(0, 2, size=n_homes)
    log_radon = np.random.normal(1, 1, size=n_homes)

    # Spread the homes over the counties round-robin: `d` full passes over all
    # counties, then the first `r` counties once more, giving n_homes entries.
    d, r = divmod(n_homes, counties)
    county = np.hstack((
        np.tile(np.arange(counties, dtype=int), d),
        np.arange(r)
    ))
    with pm.Model() as model:
        sigma_a = pm.HalfCauchy('sigma_a', 5)
        gamma = pm.Normal('gamma', mu=0., sd=1e5, shape=3)
        mu_a = pm.Deterministic('mu_a', gamma[0] + gamma[1]*uranium + gamma[2]*xbar)
        eps_a = pm.Normal('eps_a', mu=0, sd=sigma_a, shape=counties)
        a = pm.Deterministic('a', mu_a + eps_a[county])
        b = pm.Normal('b', mu=0., sd=1e15)
        sigma_y = pm.Uniform('sigma_y', lower=0, upper=100)
        # y_hat is deliberately left unnamed, so it must not show up as a node
        y_hat = a + b * floor_measure
        pm.Normal('y_like', mu=y_hat, sd=sigma_y, observed=log_radon)

    compute_graph = {
        'sigma_a': set(),
        'gamma': set(),
        'mu_a': {'gamma'},
        'eps_a': {'sigma_a'},
        'a': {'mu_a', 'eps_a'},
        'b': set(),
        'sigma_y': set(),
        'y_like': {'a', 'b', 'sigma_y'}
    }
    # keys are tied to the sizes used above so the two cannot drift apart
    plates = {
        (): {'b', 'sigma_a', 'sigma_y'},
        (3,): {'gamma'},
        (counties,): {'eps_a'},
        (n_homes,): {'a', 'mu_a', 'y_like'},
    }
    return model, compute_graph, plates


class TestSimpleModel(SeededTest):
    """Check ModelGraph against the hand-written expectations of radon_model."""

    @classmethod
    def setup_class(cls):
        # build the fixture once and share it between all tests of the class
        fixture = radon_model()
        cls.model, cls.compute_graph, cls.plates = fixture
        cls.model_graph = ModelGraph(cls.model)

    def test_inputs(self):
        """Each variable on its own reports the expected parents."""
        for child, parents in self.compute_graph.items():
            assert self.model_graph.get_parents(self.model[child]) == parents

    def test_compute_graph(self):
        """The whole name -> parents map matches the expectation."""
        computed = self.model_graph.make_compute_graph()
        assert self.compute_graph == computed

    def test_plates(self):
        """Variables are grouped into the expected plates."""
        found = self.model_graph.get_plates()
        assert self.plates == found

    def test_graphviz(self):
        """Both entry points build a graph whose source names every variable."""
        # just make sure everything runs without error
        builders = (
            self.model_graph.make_graph,
            lambda: model_to_graphviz(self.model),
        )
        for build in builders:
            # build lazily so the second graph is only made after the first one passed
            g = build()
            for key in self.compute_graph:
                assert key in g.source

1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
bokeh>=0.12.13
CommonMark==0.5.4
graphviz>=0.8.3
h5py>=2.7.0
ipython
Keras>=2.0.8
Expand Down
1 change: 1 addition & 0 deletions scripts/create_testenv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ then
fi
fi
conda install --yes numpy scipy mkl-service
conda install --yes -c conda-forge python-graphviz

pip install --upgrade pip

Expand Down