Skip to content

Commit 41e1e9d

Browse files
authored
Merge pull request #3491 from rpgoldman/graph-data
Mark `pm.Data` nodes in graphviz graphs.
2 parents 3fa26cf + 9450ec7 commit 41e1e9d

File tree

4 files changed

+51
-18
lines changed

4 files changed

+51
-18
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Release Notes
22

3+
## PyMC3 3.8 (on deck)
4+
5+
### New features
6+
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
7+
38
## PyMC3 3.7 (May 29 2019)
49

510
### New features

pymc3/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def align_minibatches(batches=None):
390390

391391
class Data:
392392
"""Data container class that wraps the theano SharedVariable class
393-
and let the model be aware of its inputs and outputs.
393+
and lets the model be aware of its inputs and outputs.
394394
395395
Parameters
396396
----------

pymc3/model_graph.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from collections import deque
2-
from typing import Iterator, Optional, MutableSet
2+
from typing import Dict, Iterator, Set, Optional
3+
4+
VarName = str
35

46
from theano.gof.graph import stack_search
57
from theano.compile import SharedVariable
68
from theano.tensor import Tensor
79

810
from .util import get_default_varnames
11+
from .model import ObservedRV
912
import pymc3 as pm
1013

11-
# this is a placeholder for a better characterization of the type
12-
# of variables in a model.
13-
RV = Tensor
14-
1514

1615
class ModelGraph:
1716
def __init__(self, model):
@@ -30,16 +29,16 @@ def get_deterministics(self, var):
3029
deterministics.append(v)
3130
return deterministics
3231

33-
def _get_ancestors(self, var, func) -> MutableSet[RV]:
32+
def _get_ancestors(self, var: Tensor, func) -> Set[Tensor]:
3433
"""Get all ancestors of a function, doing some accounting for deterministics.
3534
"""
3635

3736
# this contains all of the variables in the model EXCEPT var...
3837
vars = set(self.var_list)
3938
vars.remove(var)
4039

41-
blockers = set()
42-
retval = set()
40+
blockers = set() # type: Set[Tensor]
41+
retval = set() # type: Set[Tensor]
4342
def _expand(node) -> Optional[Iterator[Tensor]]:
4443
if node in blockers:
4544
return None
@@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
5857
mode='bfs')
5958
return retval
6059

61-
def _filter_parents(self, var, parents):
60+
def _filter_parents(self, var, parents) -> Set[VarName]:
6261
"""Get direct parents of a var, as strings"""
63-
keep = set()
62+
keep = set() # type: Set[VarName]
6463
for p in parents:
6564
if p == var:
6665
continue
@@ -73,7 +72,7 @@ def _filter_parents(self, var, parents):
7372
raise AssertionError('Do not know what to do with {}'.format(str(p)))
7473
return keep
7574

76-
def get_parents(self, var):
75+
def get_parents(self, var: Tensor) -> Set[VarName]:
7776
"""Get the named nodes that are direct inputs to the var"""
7877
if hasattr(var, 'transformed'):
7978
func = var.transformed.logpt
@@ -85,11 +84,26 @@ def get_parents(self, var):
8584
parents = self._get_ancestors(var, func)
8685
return self._filter_parents(var, parents)
8786

88-
def make_compute_graph(self):
87+
def make_compute_graph(self) -> Dict[str, Set[VarName]]:
8988
"""Get map of var_name -> set(input var names) for the model"""
90-
input_map = {}
89+
input_map = {} # type: Dict[str, Set[VarName]]
90+
def update_input_map(key: str, val: Set[VarName]):
91+
if key in input_map:
92+
input_map[key] = input_map[key].union(val)
93+
else:
94+
input_map[key] = val
95+
9196
for var_name in self.var_names:
92-
input_map[var_name] = self.get_parents(self.model[var_name])
97+
var = self.model[var_name]
98+
update_input_map(var_name, self.get_parents(var))
99+
if isinstance(var, ObservedRV):
100+
try:
101+
obs_name = var.observations.name
102+
if obs_name:
103+
input_map[var_name] = input_map[var_name].difference(set([obs_name]))
104+
update_input_map(obs_name, set([var_name]))
105+
except AttributeError:
106+
pass
93107
return input_map
94108

95109
def _make_node(self, var_name, graph):
@@ -101,12 +115,19 @@ def _make_node(self, var_name, graph):
101115
if isinstance(v, pm.model.ObservedRV):
102116
attrs['style'] = 'filled'
103117

118+
# render Data nodes as rounded rectangles instead of plain rectangles
104119
if isinstance(v, SharedVariable):
105-
attrs['style'] = 'filled'
120+
attrs['style'] = 'rounded, filled'
106121

107122
# Get name for node
108-
if hasattr(v, 'distribution'):
123+
if v in self.model.potentials:
124+
distribution = 'Potential'
125+
attrs['shape'] = 'octagon'
126+
elif hasattr(v, 'distribution'):
109127
distribution = v.distribution.__class__.__name__
128+
elif isinstance(v, SharedVariable):
129+
distribution = 'Data'
130+
attrs['shape'] = 'box'
110131
else:
111132
distribution = 'Deterministic'
112133
attrs['shape'] = 'box'

pymc3/tests/test_data_container.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,12 @@ def test_model_to_graphviz_for_model_with_data_container(self):
101101
pm.sample(1000, init=None, tune=1000, chains=1)
102102

103103
g = pm.model_to_graphviz(model)
104-
text = 'x [label="x ~ Deterministic" shape=box style=filled]'
104+
105+
# Data node rendered correctly?
106+
text = 'x [label="x ~ Data" shape=box style="rounded, filled"]'
107+
assert text in g.source
108+
# Didn't break ordinary variables?
109+
text = 'beta [label="beta ~ Normal"]'
110+
assert text in g.source
111+
text = 'obs [label="obs ~ Normal" style=filled]'
105112
assert text in g.source

0 commit comments

Comments
 (0)