|
21 | 21 | from aesara.graph import Apply
|
22 | 22 | from aesara.graph.basic import ancestors, walk
|
23 | 23 | from aesara.scalar.basic import Cast
|
| 24 | +from aesara.tensor.basic import get_scalar_constant_value |
24 | 25 | from aesara.tensor.elemwise import Elemwise
|
25 | 26 | from aesara.tensor.random.op import RandomVariable
|
26 | 27 | from aesara.tensor.var import TensorConstant, TensorVariable
|
27 | 28 |
|
28 | 29 | import pymc as pm
|
29 | 30 |
|
| 31 | +from pymc.distributions import Discrete |
| 32 | +from pymc.distributions.discrete import DiracDelta |
30 | 33 | from pymc.util import get_default_varnames, get_var_name
|
31 | 34 |
|
32 | 35 | VarName = NewType("VarName", str)
|
33 | 36 |
|
34 | 37 |
|
| 38 | +def check_zip_graph_from_components(components): |
| 39 | + """ |
| 40 | + This helper function checks if a mixture sub-graph corresponds to a |
| 41 | + zero-inflated distribution using its components, a list of length two. |
| 42 | + """ |
| 43 | + if not any(isinstance(var.owner.op, DiracDelta) for var in components): |
| 44 | + return False |
| 45 | + |
| 46 | + dirac_delta_idx = 1 - int(isinstance(components[0].owner.op, DiracDelta)) |
| 47 | + dirac_delta = components[dirac_delta_idx] |
| 48 | + other_comp = components[1 - dirac_delta_idx] |
| 49 | + |
| 50 | + return (get_scalar_constant_value(dirac_delta.owner.inputs[3]) == 0) and isinstance( |
| 51 | + other_comp.owner.op, Discrete |
| 52 | + ) |
| 53 | + |
| 54 | + |
35 | 55 | class ModelGraph:
|
36 | 56 | def __init__(self, model):
|
37 | 57 | self.model = model
|
@@ -154,16 +174,40 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
|
154 | 174 | shape = "box"
|
155 | 175 | style = "rounded, filled"
|
156 | 176 | label = f"{var_name}\n~\nMutableData"
|
157 |
| - elif v.owner and isinstance(v.owner.op, RandomVariable): |
| 177 | + elif v.owner and (v in self.model.basic_RVs): |
158 | 178 | shape = "ellipse"
|
159 |
| - if hasattr(v.tag, "observations"): |
| 179 | + if v in self.model.observed_RVs: |
160 | 180 | # observed RV
|
161 | 181 | style = "filled"
|
162 | 182 | else:
|
163 |
| - shape = "ellipse" |
164 | 183 | style = None
|
165 | 184 | symbol = v.owner.op.__class__.__name__
|
166 |
| - if symbol.endswith("RV"): |
| 185 | + if symbol == "MarginalMixtureRV": |
| 186 | + components = v.owner.inputs[2:] |
| 187 | + if len(components) == 2: |
| 188 | + component_names = [ |
| 189 | + var.owner.op.__class__.__name__.replace("Unmeasurable", "")[:-2] |
| 190 | + for var in components |
| 191 | + ] |
| 192 | + if check_zip_graph_from_components(components): |
| 193 | + # ZeroInflated distribution |
| 194 | + component_names.remove("DiracDelta") |
| 195 | + symbol = f"ZeroInflated{component_names[0]}" |
| 196 | + else: |
| 197 | + # X-Y mixture |
| 198 | + symbol = f"{'-'.join(component_names)}Mixture" |
| 199 | + elif len(components) == 1: |
| 200 | + # single component dispatch mixture |
| 201 | + symbol = f"{components[0].owner.op.__class__.__name__.replace('Unmeasurable', '')[:-2]}Mixture" |
| 202 | + else: |
| 203 | + symbol = symbol[:-2] # just MarginalMixture |
| 204 | + elif symbol == "CensoredRV": |
| 205 | + censored_dist = v.owner.inputs[0] |
| 206 | + symbol = symbol[:-2] + censored_dist.owner.op.__class__.__name__[:-2] |
| 207 | + elif symbol == "TruncatedRV": |
| 208 | + truncated_dist = v.owner.op.base_rv_op |
| 209 | + symbol = symbol[:-2] + truncated_dist.__class__.__name__[:-2] |
| 210 | + elif symbol.endswith("RV"): |
167 | 211 | symbol = symbol[:-2]
|
168 | 212 | label = f"{var_name}\n~\n{symbol}"
|
169 | 213 | else:
|
|
0 commit comments