Skip to content

Commit c35c676

Browse files
Fixed printing of mixtures with single component dispatching
1 parent 5839b8d commit c35c676

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

pymc/printing.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
from typing import Union
1818

1919
from aesara.graph.basic import walk
20-
from aesara.tensor.basic import TensorVariable, Variable
20+
from aesara.tensor.basic import MakeVector, TensorVariable, Variable
2121
from aesara.tensor.elemwise import DimShuffle
2222
from aesara.tensor.random.basic import RandomVariable
2323
from aesara.tensor.var import TensorConstant
2424

2525
from pymc.model import Model
2626

27+
# from pymc.distributions.discrete import UnmeasurableConstantRV
28+
2729
__all__ = [
2830
"str_for_dist",
2931
"str_for_symbolic_dist",
@@ -61,11 +63,12 @@ def str_for_symbolic_dist(
6163
) -> str:
6264
def dispatch_comp_str(var, formatting=formatting, include_params=include_params):
6365
if var.name:
64-
return str_for_dist(var, formatting=formatting, include_params=include_params)
66+
return var.name
6567
if isinstance(var, TensorConstant):
66-
return _str_for_constant(var, formatting)
67-
if var.owner.op.name == "constant":
68-
return "hello"
68+
return _str_for_constant(var, formatting, print_vector=True)
69+
if isinstance(var.owner.op, MakeVector):
70+
# psi in some zero inflated distribution
71+
return dispatch_comp_str(var.owner.inputs[1])
6972

7073
# else it's a Mixture component initialized by the .dist() API
7174

@@ -81,27 +84,36 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
8184

8285
if include_params:
8386
if "ZeroInflated" in rv.owner.op._print_name[0]:
84-
start_idx_para = 2
87+
# position 2 is just a constant_rv{0, (0,), shape, False}.1
88+
assert rv.owner.inputs[2].owner.op.__class__.__name__ == "UnmeasurableConstantRV"
89+
dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[3:]
90+
8591
elif "Mixture" in rv.owner.op._print_name[0]:
86-
start_idx_para = 1
8792

8893
if len(rv.owner.inputs) == 3:
8994
# is a single component!
9095
# (rng, weights, single_component)
91-
pass
96+
rv.owner.op._print_name = (
97+
f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture",
98+
"\\operatorname{" + f"{rv.owner.inputs[2].owner.op.name.capitalize()}Mixture}}",
99+
)
100+
dist_parameters = [rv.owner.inputs[1]] + rv.owner.inputs[2].owner.inputs[3:]
101+
else:
102+
dist_parameters = rv.owner.inputs[1:]
92103

93104
elif "Censored" in rv.owner.op._print_name[0]:
94-
start_idx_para = 2
105+
dist_parameters = rv.owner.inputs[2:]
95106
else:
96107
# Latex representation for the SymbolicDistribution has not been implemented.
97108
# Hoping for the best here!
98-
start_idx_para = 2
109+
dist_parameters = rv.owner.inputs[2:]
99110

100111
dist_args = [
101112
dispatch_comp_str(dist_para, formatting=formatting, include_params=include_params)
102-
for dist_para in rv.owner.inputs[start_idx_para:]
113+
for dist_para in dist_parameters
103114
]
104115

116+
# code below copied from str_for_dist
105117
print_name = rv.name if rv.name is not None else "<unnamed>"
106118
if "latex" in formatting:
107119
print_name = r"\text{" + _latex_escape(print_name) + "}"
@@ -210,11 +222,13 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
210222
return _str
211223

212224

213-
def _str_for_constant(var: TensorConstant, formatting: str) -> str:
225+
def _str_for_constant(var: TensorConstant, formatting: str, print_vector: bool = False) -> str:
214226
if len(var.data.shape) == 0:
215227
return f"{var.data:.3g}"
216228
elif len(var.data.shape) == 1 and var.data.shape[0] == 1:
217229
return f"{var.data[0]:.3g}"
230+
elif len(var.data.shape) == 1 and print_vector:
231+
return "[" + ", ".join([f"{const:.3g}" for const in var.data]) + "]"
218232
elif "latex" in formatting:
219233
return r"\text{<constant>}"
220234
else:

0 commit comments

Comments
 (0)