diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 512b529ae5..80718fec56 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -161,6 +161,8 @@ def make_graph(self): '\tconda install -c conda-forge python-graphviz') graph = graphviz.Digraph(self.model.name) for shape, var_names in self.get_plates().items(): + if isinstance(shape, SharedVariable): + shape = shape.eval() label = ' x '.join(map('{:,d}'.format, shape)) if label: # must be preceded by 'cluster' to get a box around it diff --git a/pymc3/tests/test_model_graph.py b/pymc3/tests/test_model_graph.py index df4c4c0bba..609671b7b3 100644 --- a/pymc3/tests/test_model_graph.py +++ b/pymc3/tests/test_model_graph.py @@ -1,7 +1,7 @@ import numpy as np import pymc3 as pm +import theano as th from pymc3.model_graph import ModelGraph, model_to_graphviz - from .helpers import SeededTest @@ -14,6 +14,8 @@ def radon_model(): floor_measure = np.random.randint(0, 2, size=n_homes) log_radon = np.random.normal(1, 1, size=n_homes) + floor_measure = th.shared(floor_measure) + d, r = divmod(919, 85) county = np.hstack(( np.tile(np.arange(counties, dtype=int), d),