Skip to content

Add graphviz support for SymbolicRVs #6149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,37 @@
from aesara.graph import Apply
from aesara.graph.basic import ancestors, walk
from aesara.scalar.basic import Cast
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

import pymc as pm

from pymc.distributions import Discrete
from pymc.distributions.discrete import DiracDelta
from pymc.util import get_default_varnames, get_var_name

VarName = NewType("VarName", str)


def check_zip_graph_from_components(components):
    """Return True when a two-component mixture sub-graph encodes a
    zero-inflated distribution.

    ``components`` is a list of two mixture component variables. The graph
    counts as zero-inflated when one component is a ``DiracDelta`` whose
    point mass sits at zero and the other component is a discrete
    distribution.
    """
    is_dirac = [isinstance(var.owner.op, DiracDelta) for var in components]
    if not any(is_dirac):
        # No point-mass component, so this cannot be zero-inflated.
        return False

    dirac_idx = 0 if is_dirac[0] else 1
    dirac_delta = components[dirac_idx]
    other_comp = components[1 - dirac_idx]

    # inputs[3] holds the DiracDelta's constant value; it must be exactly 0,
    # and the remaining component must be a Discrete distribution.
    point_mass_at_zero = get_scalar_constant_value(dirac_delta.owner.inputs[3]) == 0
    return point_mass_at_zero and isinstance(other_comp.owner.op, Discrete)


class ModelGraph:
def __init__(self, model):
self.model = model
Expand Down Expand Up @@ -154,17 +174,42 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
shape = "box"
style = "rounded, filled"
label = f"{var_name}\n~\nMutableData"
elif v.owner and isinstance(v.owner.op, RandomVariable):
elif v.owner and (v in self.model.basic_RVs):
shape = "ellipse"
if hasattr(v.tag, "observations"):
if v in self.model.observed_RVs:
# observed RV
style = "filled"
else:
shape = "ellipse"
style = None
symbol = v.owner.op.__class__.__name__
if symbol.endswith("RV"):
symbol = symbol[:-2]
symbol = v.owner.op._print_name[0]
if symbol == "MarginalMixture":
components = v.owner.inputs[2:]
if len(components) == 2:
component_names = [var.owner.op._print_name[0] for var in components]
if check_zip_graph_from_components(components):
# ZeroInflated distribution
component_names.remove("DiracDelta")
symbol = f"ZeroInflated{component_names[0]}"
else:
# X-Y mixture
symbol = f"{'-'.join(component_names)}Mixture"
elif len(components) == 1:
# single component dispatch mixture
symbol = f"{components[0].owner.op._print_name[0]}Mixture"
else:
symbol = symbol[:-2] # just MarginalMixture
elif symbol == "Censored":
censored_dist = v.owner.inputs[0]
symbol = symbol + censored_dist.owner.op._print_name[0]
elif symbol == "Truncated":
truncated_dist = v.owner.op.base_rv_op
symbol = symbol + truncated_dist._print_name[0]
elif symbol == "RandomWalk":
innovation_dist = v.owner.inputs[1].owner.op._print_name[0]
if innovation_dist == "Normal":
symbol = "Gaussian" + symbol
else:
symbol = innovation_dist + symbol
label = f"{var_name}\n~\n{symbol}"
else:
shape = "box"
Expand Down
63 changes: 63 additions & 0 deletions pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@

import pymc as pm

from pymc.distributions import (
Cauchy,
Censored,
GaussianRandomWalk,
Mixture,
Normal,
RandomWalk,
StudentT,
Truncated,
ZeroInflatedPoisson,
)
from pymc.exceptions import ImputationWarning
from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
from pymc.tests.helpers import SeededTest
Expand Down Expand Up @@ -360,3 +371,55 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
mg = ModelGraph(model_with_different_descendants())
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
assert mg.make_compute_graph(var_names=var_names) == compute_graph


@pytest.mark.parametrize(
    "symbolic_dist, dist_kwargs, display_name",
    [
        (ZeroInflatedPoisson, {"psi": 0.5, "mu": 5}, "ZeroInflatedPoisson"),
        (
            Censored,
            {"dist": Normal.dist(Normal.dist(0.0, 5.0), 2.0), "lower": -2, "upper": 2},
            "CensoredNormal",
        ),
        (
            Mixture,
            {"w": [0.5, 0.5], "comp_dists": Normal.dist(0.0, 5.0, shape=(2,))},
            "NormalMixture",
        ),
        (
            Mixture,
            {"w": [0.5, 0.5], "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0)]},
            "Normal-StudentTMixture",
        ),
        (
            Mixture,
            {
                "w": [0.3, 0.45, 0.25],
                "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0), Cauchy.dist(1.0, 1.0)],
            },
            "MarginalMixture",
        ),
        (
            GaussianRandomWalk,
            {"init_dist": Normal.dist(0.0, 5.0), "steps": 10},
            "GaussianRandomWalk",
        ),
        (Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"),
        (
            RandomWalk,
            {
                "innovation_dist": pm.StudentT.dist(7),
                "init_dist": pm.Normal.dist(0, 1),
                "steps": 10,
            },
            "StudentTRandomWalk",
        ),
    ],
)
def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name):
    """Each symbolic RV should be rendered in graphviz with a readable name."""
    model = pm.Model()
    with model:
        symbolic_dist("x", **dist_kwargs)

    # The expected human-readable label must appear in the DOT source.
    dot_source = model_to_graphviz(model).source
    assert display_name in dot_source