|
| 1 | +from theano.gof.graph import inputs |
| 2 | + |
| 3 | +from .util import get_default_varnames |
| 4 | +import pymc3 as pm |
| 5 | + |
| 6 | + |
class ModelGraph(object):
    """Derive the variable graph of a PyMC3 model and render it with graphviz."""

    def __init__(self, model):
        """Cache the model together with the lookups used by the other methods."""
        self.model = model
        named_vars = self.model.named_vars
        self.var_names = get_default_varnames(named_vars, include_transformed=False)
        self.var_list = named_vars.values()
        # Maps each transformed variable back to the name of the variable
        # that owns it.
        self.transform_map = {}
        for named in self.var_list:
            if hasattr(named, 'transformed'):
                self.transform_map[named.transformed] = named.name
        self._deterministics = None

    def get_deterministics(self, var):
        """Return the deterministic nodes of the graph, excluding ``var``.

        A named variable counts as deterministic when it has neither a
        ``transformed`` nor a ``logpt`` attribute.
        """
        marker_attrs = ('transformed', 'logpt')
        return [
            candidate for candidate in self.var_list
            if candidate != var
            and not any(hasattr(candidate, attr) for attr in marker_attrs)
        ]

    def _inputs(self, var, func, blockers=None):
        """Return the inputs of ``func`` that are named model variables other than ``var``."""
        named_inputs = set()
        for node in inputs([func], blockers=blockers):
            if node in self.var_list and node != var:
                named_inputs.add(node)
        return named_inputs

    def _get_inputs(self, var, func):
        """Collect the inputs of ``func``, with extra accounting for deterministics.

        When a deterministic is an input, ``theano.gof.graph.inputs`` returns
        only the inputs *to the deterministic*. Passing the deterministic as a
        blocker makes the traversal skip those nodes instead.
        """
        deterministics = self.get_deterministics(var)
        upstream = self._inputs(var, func)
        parents = self._inputs(var, func, blockers=deterministics)
        if parents == upstream:
            # Blocking on deterministics changed nothing.
            return parents

        # Keep only the deterministics whose blocking hides some upstream input.
        blocked_inputs = {}
        for det in deterministics:
            visible = set(inputs([func], blockers=[det]))
            if upstream - visible:
                blocked_inputs[det] = visible

        # A deterministic becomes a parent when its visible set is contained
        # in the visible set of every retained deterministic.
        for det, visible in blocked_inputs.items():
            if all(visible.issubset(other) for other in blocked_inputs.values()):
                parents.add(det)
        return parents

    def _filter_parents(self, var, parents):
        """Reduce ``parents`` to the set of parent names of ``var``, as strings.

        Raises
        ------
        AssertionError
            If a parent is neither a default variable name nor a known
            transformed variable.
        """
        keep = set()
        for parent in parents:
            if parent == var:
                continue
            if parent.name in self.var_names:
                keep.add(parent.name)
                continue
            if parent not in self.transform_map:
                raise AssertionError('Do not know what to do with {}'.format(str(parent)))
            owner_name = self.transform_map[parent]
            # A variable is not reported as the parent of itself via its transform.
            if owner_name != var.name:
                keep.add(owner_name)
        return keep

    def get_parents(self, var):
        """Return the names of the named nodes that are direct inputs to ``var``."""
        # Pick the expression whose inputs are inspected: the transformed
        # variable's logpt, the variable's own logpt, or the variable itself.
        func = var
        if hasattr(var, 'transformed'):
            func = var.transformed.logpt
        elif hasattr(var, 'logpt'):
            func = var.logpt

        return self._filter_parents(var, self._get_inputs(var, func))

    def make_compute_graph(self):
        """Return a map of var_name -> set(input var names) for the model."""
        return {name: self.get_parents(self.model[name]) for name in self.var_names}

    def _make_node(self, var_name, graph):
        """Attach the variable named ``var_name`` to a graphviz Digraph."""
        v = self.model[var_name]

        # Observed variables are drawn filled.
        attrs = {}
        if isinstance(v, pm.model.ObservedRV):
            attrs['style'] = 'filled'

        # Variables without a distribution are labelled Deterministic and boxed.
        if hasattr(v, 'distribution'):
            distribution = v.distribution.__class__.__name__
        else:
            distribution = 'Deterministic'
            attrs['shape'] = 'box'

        node_label = '{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution)
        graph.node(var_name, node_label, **attrs)

    def get_plates(self):
        """Group variable names into plates by shape.

        Rough plate detection: variables are grouped purely by the shape of
        their observations, their ``dshape``, or their test value, in that
        order of preference. Will be wrong if there are two plates with the
        same shape.

        Returns
        -------
        dict: shape tuple -> set[str]
        """
        plates = {}
        for var_name in self.var_names:
            v = self.model[var_name]
            if hasattr(v, 'observations'):
                shape = v.observations.shape
            elif hasattr(v, 'dshape'):
                shape = v.dshape
            else:
                shape = v.tag.test_value.shape
            # A shape of (1,) is treated like a scalar, i.e. no plate.
            if shape == (1,):
                shape = tuple()
            plates.setdefault(shape, set()).add(var_name)
        return plates

    def make_graph(self):
        """Make a graphviz Digraph of the PyMC3 model.

        Returns
        -------
        graphviz.Digraph

        Raises
        ------
        ImportError
            If the graphviz python library cannot be imported.
        """
        try:
            import graphviz
        except ImportError:
            raise ImportError('This function requires the python library graphviz, along with binaries. '
                              'The easiest way to install all of this is by running\n\n'
                              '\tconda install -c conda-forge python-graphviz')

        graph = graphviz.Digraph(self.model.name)
        for shape, plate_members in self.get_plates().items():
            label = ' x '.join('{:,d}'.format(dim) for dim in shape)
            if not label:
                # Empty shape: nodes go directly on the top-level graph.
                for member in plate_members:
                    self._make_node(member, graph)
                continue
            # must be preceded by 'cluster' to get a box around it
            with graph.subgraph(name='cluster' + label) as sub:
                for member in plate_members:
                    self._make_node(member, sub)
                # plate label goes bottom right
                sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')

        # Edges point from each parent to its child.
        for child, parent_names in self.make_compute_graph().items():
            for parent_name in parent_names:
                graph.edge(parent_name, child)
        return graph
| 160 | + |
| 161 | + |
def model_to_graphviz(model=None):
    """Produce a graphviz Digraph from a PyMC3 model.

    The ``model`` argument is resolved through ``pm.modelcontext`` before
    the graph is built.

    Requires graphviz, which may be installed most easily with
        conda install -c conda-forge python-graphviz

    Alternatively, you may install the `graphviz` binaries yourself,
    and then `pip install graphviz` to get the python bindings.  See
    http://graphviz.readthedocs.io/en/stable/manual.html
    for more information.
    """
    resolved_model = pm.modelcontext(model)
    graph_builder = ModelGraph(resolved_model)
    return graph_builder.make_graph()
0 commit comments