Skip to content

Commit f481347

Browse files
Add graphviz support for SymbolicRVs
1 parent ce89404 commit f481347

File tree

1 file changed

+48
-4
lines changed

1 file changed

+48
-4
lines changed

pymc/model_graph.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,37 @@
2121
from aesara.graph import Apply
2222
from aesara.graph.basic import ancestors, walk
2323
from aesara.scalar.basic import Cast
24+
from aesara.tensor.basic import get_scalar_constant_value
2425
from aesara.tensor.elemwise import Elemwise
2526
from aesara.tensor.random.op import RandomVariable
2627
from aesara.tensor.var import TensorConstant, TensorVariable
2728

2829
import pymc as pm
2930

31+
from pymc.distributions import Discrete
32+
from pymc.distributions.discrete import DiracDelta
3033
from pymc.util import get_default_varnames, get_var_name
3134

3235
VarName = NewType("VarName", str)
3336

3437

38+
def check_zip_graph_from_components(components):
39+
"""
40+
This helper function checks if a mixture sub-graph corresponds to a
41+
zero-inflated distribution using its components, a list of length two.
42+
"""
43+
if not any(isinstance(var.owner.op, DiracDelta) for var in components):
44+
return False
45+
46+
dirac_delta_idx = 1 - int(isinstance(components[0].owner.op, DiracDelta))
47+
dirac_delta = components[dirac_delta_idx]
48+
other_comp = components[1 - dirac_delta_idx]
49+
50+
return (get_scalar_constant_value(dirac_delta.owner.inputs[3]) == 0) and isinstance(
51+
other_comp.owner.op, Discrete
52+
)
53+
54+
3555
class ModelGraph:
3656
def __init__(self, model):
3757
self.model = model
@@ -154,16 +174,40 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
154174
shape = "box"
155175
style = "rounded, filled"
156176
label = f"{var_name}\n~\nMutableData"
157-
elif v.owner and isinstance(v.owner.op, RandomVariable):
177+
elif v.owner and (v in self.model.basic_RVs):
158178
shape = "ellipse"
159-
if hasattr(v.tag, "observations"):
179+
if v in self.model.observed_RVs:
160180
# observed RV
161181
style = "filled"
162182
else:
163-
shape = "ellipse"
164183
style = None
165184
symbol = v.owner.op.__class__.__name__
166-
if symbol.endswith("RV"):
185+
if symbol == "MarginalMixtureRV":
186+
components = v.owner.inputs[2:]
187+
if len(components) == 2:
188+
component_names = [
189+
var.owner.op.__class__.__name__.replace("Unmeasurable", "")[:-2]
190+
for var in components
191+
]
192+
if check_zip_graph_from_components(components):
193+
# ZeroInflated distribution
194+
component_names.remove("DiracDelta")
195+
symbol = f"ZeroInflated{component_names[0]}"
196+
else:
197+
# X-Y mixture
198+
symbol = f"{'-'.join(component_names)}Mixture"
199+
elif len(components) == 1:
200+
# single component dispatch mixture
201+
symbol = f"{components[0].owner.op.__class__.__name__.replace('Unmeasurable', '')[:-2]}Mixture"
202+
else:
203+
symbol = symbol[:-2] # just MarginalMixture
204+
elif symbol == "CensoredRV":
205+
censored_dist = v.owner.inputs[0]
206+
symbol = symbol[:-2] + censored_dist.owner.op.__class__.__name__[:-2]
207+
elif symbol == "TruncatedRV":
208+
truncated_dist = v.owner.op.base_rv_op
209+
symbol = symbol[:-2] + truncated_dist.__class__.__name__[:-2]
210+
elif symbol.endswith("RV"):
167211
symbol = symbol[:-2]
168212
label = f"{var_name}\n~\n{symbol}"
169213
else:

0 commit comments

Comments
 (0)