Skip to content

Commit 09014b9

Browse files
authored
Add graphviz model graphs (#3049)
* Add graphviz model graphs * add --yes flag to conda * update tests * Comments
1 parent ed8b1dd commit 09014b9

File tree

7 files changed

+1438
-286
lines changed

7 files changed

+1438
-286
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
the unconditioned model.
2424
- SMC: remove experimental warning, allow sampling using `sample`, reduce autocorrelation from
2525
final trace.
26+
- Add `model_to_graphviz` (which uses the optional dependency `graphviz`) to
27+
plot a directed graph of a PyMC3 model using plate notation.
2628

2729
### Fixes
2830

docs/source/notebooks/multilevel_modeling.ipynb

Lines changed: 1180 additions & 286 deletions
Large diffs are not rendered by default.

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import gp
99
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
1010
from .model import *
11+
from .model_graph import model_to_graphviz
1112
from .stats import *
1213
from .sampling import *
1314
from .step_methods import *

pymc3/model_graph.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from theano.gof.graph import inputs
2+
3+
from .util import get_default_varnames
4+
import pymc3 as pm
5+
6+
7+
class ModelGraph(object):
8+
def __init__(self, model):
9+
self.model = model
10+
self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
11+
self.var_list = self.model.named_vars.values()
12+
self.transform_map = {v.transformed: v.name for v in self.var_list if hasattr(v, 'transformed')}
13+
self._deterministics = None
14+
15+
def get_deterministics(self, var):
16+
"""Compute the deterministic nodes of the graph"""
17+
deterministics = []
18+
attrs = ('transformed', 'logpt')
19+
for v in self.var_list:
20+
if v != var and all(not hasattr(v, attr) for attr in attrs):
21+
deterministics.append(v)
22+
return deterministics
23+
24+
def _inputs(self, var, func, blockers=None):
25+
"""Get inputs to a function that are also named PyMC3 variables"""
26+
return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var])
27+
28+
def _get_inputs(self, var, func):
29+
"""Get all inputs to a function, doing some accounting for deterministics
30+
31+
Specifically, if a deterministic is an input, theano.gof.graph.inputs will
32+
return only the inputs *to the deterministic*. However, if we pass in the
33+
deterministic as a blocker, it will skip those nodes.
34+
"""
35+
deterministics = self.get_deterministics(var)
36+
upstream = self._inputs(var, func)
37+
parents = self._inputs(var, func, blockers=deterministics)
38+
if parents != upstream:
39+
det_map = {}
40+
for d in deterministics:
41+
d_set = {j for j in inputs([func], blockers=[d])}
42+
if upstream - d_set:
43+
det_map[d] = d_set
44+
for d, d_set in det_map.items():
45+
if all(d_set.issubset(other) for other in det_map.values()):
46+
parents.add(d)
47+
return parents
48+
49+
def _filter_parents(self, var, parents):
50+
"""Get direct parents of a var, as strings"""
51+
keep = set()
52+
for p in parents:
53+
if p == var:
54+
continue
55+
elif p.name in self.var_names:
56+
keep.add(p.name)
57+
elif p in self.transform_map:
58+
if self.transform_map[p] != var.name:
59+
keep.add(self.transform_map[p])
60+
else:
61+
raise AssertionError('Do not know what to do with {}'.format(str(p)))
62+
return keep
63+
64+
def get_parents(self, var):
65+
"""Get the named nodes that are direct inputs to the var"""
66+
if hasattr(var, 'transformed'):
67+
func = var.transformed.logpt
68+
elif hasattr(var, 'logpt'):
69+
func = var.logpt
70+
else:
71+
func = var
72+
73+
parents = self._get_inputs(var, func)
74+
return self._filter_parents(var, parents)
75+
76+
def make_compute_graph(self):
77+
"""Get map of var_name -> set(input var names) for the model"""
78+
input_map = {}
79+
for var_name in self.var_names:
80+
input_map[var_name] = self.get_parents(self.model[var_name])
81+
return input_map
82+
83+
def _make_node(self, var_name, graph):
84+
"""Attaches the given variable to a graphviz Digraph"""
85+
v = self.model[var_name]
86+
87+
# styling for node
88+
attrs = {}
89+
if isinstance(v, pm.model.ObservedRV):
90+
attrs['style'] = 'filled'
91+
92+
# Get name for node
93+
if hasattr(v, 'distribution'):
94+
distribution = v.distribution.__class__.__name__
95+
else:
96+
distribution = 'Deterministic'
97+
attrs['shape'] = 'box'
98+
99+
graph.node(var_name,
100+
'{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution),
101+
**attrs)
102+
103+
def get_plates(self):
104+
""" Rough but surprisingly accurate plate detection.
105+
106+
Just groups by the shape of the underlying distribution. Will be wrong
107+
if there are two plates with the same shape.
108+
109+
Returns
110+
-------
111+
dict: str -> set[str]
112+
"""
113+
plates = {}
114+
for var_name in self.var_names:
115+
v = self.model[var_name]
116+
if hasattr(v, 'observations'):
117+
shape = v.observations.shape
118+
elif hasattr(v, 'dshape'):
119+
shape = v.dshape
120+
else:
121+
shape = v.tag.test_value.shape
122+
if shape == (1,):
123+
shape = tuple()
124+
if shape not in plates:
125+
plates[shape] = set()
126+
plates[shape].add(var_name)
127+
return plates
128+
129+
def make_graph(self):
130+
"""Make graphviz Digraph of PyMC3 model
131+
132+
Returns
133+
-------
134+
graphviz.Digraph
135+
"""
136+
try:
137+
import graphviz
138+
except ImportError:
139+
raise ImportError('This function requires the python library graphviz, along with binaries. '
140+
'The easiest way to install all of this is by running\n\n'
141+
'\tconda install -c conda-forge python-graphviz')
142+
graph = graphviz.Digraph(self.model.name)
143+
for shape, var_names in self.get_plates().items():
144+
label = ' x '.join(map('{:,d}'.format, shape))
145+
if label:
146+
# must be preceded by 'cluster' to get a box around it
147+
with graph.subgraph(name='cluster' + label) as sub:
148+
for var_name in var_names:
149+
self._make_node(var_name, sub)
150+
# plate label goes bottom right
151+
sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')
152+
else:
153+
for var_name in var_names:
154+
self._make_node(var_name, graph)
155+
156+
for key, values in self.make_compute_graph().items():
157+
for value in values:
158+
graph.edge(value, key)
159+
return graph
160+
161+
162+
def model_to_graphviz(model=None):
163+
"""Produce a graphviz Digraph from a PyMC3 model.
164+
165+
Requires graphviz, which may be installed most easily with
166+
conda install -c conda-forge python-graphviz
167+
168+
Alternatively, you may install the `graphviz` binaries yourself,
169+
and then `pip install graphviz` to get the python bindings. See
170+
http://graphviz.readthedocs.io/en/stable/manual.html
171+
for more information.
172+
"""
173+
model = pm.modelcontext(model)
174+
return ModelGraph(model).make_graph()

pymc3/tests/test_model_graph.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import pymc3 as pm
3+
from pymc3.model_graph import ModelGraph, model_to_graphviz
4+
5+
from .helpers import SeededTest
6+
7+
8+
def radon_model():
9+
"""Similar in shape to the Radon model"""
10+
n_homes = 919
11+
counties = 85
12+
uranium = np.random.normal(-.1, 0.4, size=n_homes)
13+
xbar = np.random.normal(1, 0.1, size=n_homes)
14+
floor_measure = np.random.randint(0, 2, size=n_homes)
15+
log_radon = np.random.normal(1, 1, size=n_homes)
16+
17+
d, r = divmod(919, 85)
18+
county = np.hstack((
19+
np.tile(np.arange(counties, dtype=int), d),
20+
np.arange(r)
21+
))
22+
with pm.Model() as model:
23+
sigma_a = pm.HalfCauchy('sigma_a', 5)
24+
gamma = pm.Normal('gamma', mu=0., sd=1e5, shape=3)
25+
mu_a = pm.Deterministic('mu_a', gamma[0] + gamma[1]*uranium + gamma[2]*xbar)
26+
eps_a = pm.Normal('eps_a', mu=0, sd=sigma_a, shape=counties)
27+
a = pm.Deterministic('a', mu_a + eps_a[county])
28+
b = pm.Normal('b', mu=0., sd=1e15)
29+
sigma_y = pm.Uniform('sigma_y', lower=0, upper=100)
30+
y_hat = a + b * floor_measure
31+
y_like = pm.Normal('y_like', mu=y_hat, sd=sigma_y, observed=log_radon)
32+
33+
compute_graph = {
34+
'sigma_a': set(),
35+
'gamma': set(),
36+
'mu_a': {'gamma'},
37+
'eps_a': {'sigma_a'},
38+
'a': {'mu_a', 'eps_a'},
39+
'b': set(),
40+
'sigma_y': set(),
41+
'y_like': {'a', 'b', 'sigma_y'}
42+
}
43+
plates = {
44+
(): {'b', 'sigma_a', 'sigma_y'},
45+
(3,): {'gamma'},
46+
(85,): {'eps_a'},
47+
(919,): {'a', 'mu_a', 'y_like'},
48+
}
49+
return model, compute_graph, plates
50+
51+
52+
class TestSimpleModel(SeededTest):
53+
@classmethod
54+
def setup_class(cls):
55+
cls.model, cls.compute_graph, cls.plates = radon_model()
56+
cls.model_graph = ModelGraph(cls.model)
57+
58+
def test_inputs(self):
59+
for child, parents in self.compute_graph.items():
60+
var = self.model[child]
61+
found_parents = self.model_graph.get_parents(var)
62+
assert found_parents == parents
63+
64+
def test_compute_graph(self):
65+
assert self.compute_graph == self.model_graph.make_compute_graph()
66+
67+
def test_plates(self):
68+
assert self.plates == self.model_graph.get_plates()
69+
70+
def test_graphviz(self):
71+
# just make sure everything runs without error
72+
73+
g = self.model_graph.make_graph()
74+
for key in self.compute_graph:
75+
assert key in g.source
76+
g = model_to_graphviz(self.model)
77+
for key in self.compute_graph:
78+
assert key in g.source
79+

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
bokeh>=0.12.13
22
CommonMark==0.5.4
3+
graphviz>=0.8.3
34
h5py>=2.7.0
45
ipython
56
Keras>=2.0.8

scripts/create_testenv.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ then
3434
fi
3535
fi
3636
conda install --yes numpy scipy mkl-service
37+
conda install --yes -c conda-forge python-graphviz
3738

3839
pip install --upgrade pip
3940

0 commit comments

Comments
 (0)