Skip to content

Commit ddd1d4b

Browse files
Juan OrduzmichaelosthegericardoV94
authored
Add Model.to_graphviz shortcut (#6865)
* add methods * add test structure * fix docstrings * remove networkx method * imptove tests * rename method name * add mock test * Update tests/test_model.py Co-authored-by: Michael Osthege <[email protected]> * update docsting Co-authored-by: Ricardo Vieira <[email protected]> --------- Co-authored-by: Michael Osthege <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 6691943 commit ddd1d4b

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

pymc/model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Any,
2525
Callable,
2626
Dict,
27+
Iterable,
2728
List,
2829
Literal,
2930
Optional,
@@ -65,6 +66,7 @@
6566
from pymc.initial_point import make_initial_point_fn
6667
from pymc.logprob.basic import transformed_conditional_logp
6768
from pymc.logprob.utils import ParameterValueError
69+
from pymc.model_graph import VarName, model_to_graphviz
6870
from pymc.pytensorf import (
6971
PointFunc,
7072
SeedSequenceSeed,
@@ -1879,6 +1881,53 @@ def debug_parameters(rv):
18791881
elif not verbose:
18801882
print_("You can set `verbose=True` for more details")
18811883

1884+
def to_graphviz(
1885+
self, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
1886+
):
1887+
"""Produce a graphviz Digraph from a PyMC model.
1888+
1889+
Requires graphviz, which may be installed most easily with
1890+
conda install -c conda-forge python-graphviz
1891+
1892+
Alternatively, you may install the `graphviz` binaries yourself,
1893+
and then `pip install graphviz` to get the python bindings. See
1894+
http://graphviz.readthedocs.io/en/stable/manual.html
1895+
for more information.
1896+
1897+
Parameters
1898+
----------
1899+
var_names : iterable of variable names, optional
1900+
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
1901+
formatting : str, optional
1902+
one of { "plain" }
1903+
1904+
Examples
1905+
--------
1906+
How to plot the graph of the model.
1907+
1908+
.. code-block:: python
1909+
1910+
import numpy as np
1911+
from pymc import HalfCauchy, Model, Normal, model_to_graphviz
1912+
1913+
J = 8
1914+
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
1915+
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
1916+
1917+
with Model() as schools:
1918+
1919+
eta = Normal("eta", 0, 1, shape=J)
1920+
mu = Normal("mu", 0, sigma=1e6)
1921+
tau = HalfCauchy("tau", 25)
1922+
1923+
theta = mu + tau * eta
1924+
1925+
obs = Normal("obs", theta, sigma=sigma, observed=y)
1926+
1927+
schools.to_graphviz()
1928+
"""
1929+
return model_to_graphviz(model=self, var_names=var_names, formatting=formatting)
1930+
18821931

18831932
# this is really disgusting, but it breaks a self-loop: I can't pass Model
18841933
# itself as context class init arg.

tests/test_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import unittest
1818
import warnings
1919

20+
from unittest.mock import MagicMock, patch
21+
2022
import arviz as az
2123
import cloudpickle
2224
import numpy as np
@@ -47,6 +49,7 @@
4749
from pymc.logprob.basic import conditional_logp, transformed_conditional_logp
4850
from pymc.logprob.transforms import IntervalTransform
4951
from pymc.model import Point, ValueGradFunction, modelcontext
52+
from pymc.model_graph import model_to_graphviz
5053
from pymc.util import _FutureWarningValidatingScratchpad
5154
from pymc.variational.minibatch_rv import MinibatchRandomVariable
5255
from tests.models import simple_model
@@ -1653,3 +1656,28 @@ def test_model_logp_fast_compile():
16531656

16541657
with pytensor.config.change_flags(mode="FAST_COMPILE"):
16551658
assert m.point_logps() == {"a": -1.5}
1659+
1660+
1661+
class TestModelGraphs:
1662+
@staticmethod
1663+
def school_model(J: int) -> pm.Model:
1664+
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
1665+
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
1666+
with pm.Model(coords={"school": np.arange(J)}) as schools:
1667+
eta = pm.Normal("eta", 0, 1, dims="school")
1668+
mu = pm.Normal("mu", 0, sigma=1e6)
1669+
tau = pm.HalfCauchy("tau", 25)
1670+
theta = mu + tau * eta
1671+
pm.Normal("obs", theta, sigma=sigma, observed=y, dims="school")
1672+
return schools
1673+
1674+
@pytest.mark.parametrize(
1675+
argnames="var_names", argvalues=[None, ["mu", "tau"]], ids=["all", "subset"]
1676+
)
1677+
def test_graphviz_call_function(self, var_names) -> None:
1678+
model = self.school_model(J=8)
1679+
with patch("pymc.model.model_to_graphviz") as mock_model_to_graphviz:
1680+
model.to_graphviz(var_names=var_names)
1681+
mock_model_to_graphviz.assert_called_once_with(
1682+
model=model, var_names=var_names, formatting="plain"
1683+
)

0 commit comments

Comments
 (0)