|
24 | 24 | Any,
|
25 | 25 | Callable,
|
26 | 26 | Dict,
|
| 27 | + Iterable, |
27 | 28 | List,
|
28 | 29 | Literal,
|
29 | 30 | Optional,
|
|
65 | 66 | from pymc.initial_point import make_initial_point_fn
|
66 | 67 | from pymc.logprob.basic import transformed_conditional_logp
|
67 | 68 | from pymc.logprob.utils import ParameterValueError
|
| 69 | +from pymc.model_graph import VarName, model_to_graphviz |
68 | 70 | from pymc.pytensorf import (
|
69 | 71 | PointFunc,
|
70 | 72 | SeedSequenceSeed,
|
@@ -1879,6 +1881,53 @@ def debug_parameters(rv):
|
1879 | 1881 | elif not verbose:
|
1880 | 1882 | print_("You can set `verbose=True` for more details")
|
1881 | 1883 |
|
| 1884 | + def to_graphviz( |
| 1885 | + self, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain" |
| 1886 | + ): |
| 1887 | + """Produce a graphviz Digraph from a PyMC model. |
| 1888 | +
|
| 1889 | + Requires graphviz, which may be installed most easily with |
| 1890 | + conda install -c conda-forge python-graphviz |
| 1891 | +
|
| 1892 | + Alternatively, you may install the `graphviz` binaries yourself, |
| 1893 | + and then `pip install graphviz` to get the python bindings. See |
| 1894 | + http://graphviz.readthedocs.io/en/stable/manual.html |
| 1895 | + for more information. |
| 1896 | +
|
| 1897 | + Parameters |
| 1898 | + ---------- |
| 1899 | + var_names : iterable of variable names, optional |
| 1900 | + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph |
| 1901 | + formatting : str, optional |
| 1902 | + one of { "plain" } |
| 1903 | +
|
| 1904 | + Examples |
| 1905 | + -------- |
| 1906 | + How to plot the graph of the model. |
| 1907 | +
|
| 1908 | + .. code-block:: python |
| 1909 | +
|
| 1910 | + import numpy as np |
| 1911 | + from pymc import HalfCauchy, Model, Normal, model_to_graphviz |
| 1912 | +
|
| 1913 | + J = 8 |
| 1914 | + y = np.array([28, 8, -3, 7, -1, 1, 18, 12]) |
| 1915 | + sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18]) |
| 1916 | +
|
| 1917 | + with Model() as schools: |
| 1918 | +
|
| 1919 | + eta = Normal("eta", 0, 1, shape=J) |
| 1920 | + mu = Normal("mu", 0, sigma=1e6) |
| 1921 | + tau = HalfCauchy("tau", 25) |
| 1922 | +
|
| 1923 | + theta = mu + tau * eta |
| 1924 | +
|
| 1925 | + obs = Normal("obs", theta, sigma=sigma, observed=y) |
| 1926 | +
|
| 1927 | + schools.to_graphviz() |
| 1928 | + """ |
| 1929 | + return model_to_graphviz(model=self, var_names=var_names, formatting=formatting) |
| 1930 | + |
1882 | 1931 |
|
1883 | 1932 | # this is really disgusting, but it breaks a self-loop: I can't pass Model
|
1884 | 1933 | # itself as context class init arg.
|
|
0 commit comments