17
17
from typing import Union
18
18
19
19
from aesara .graph .basic import walk
20
- from aesara .tensor .basic import TensorVariable , Variable
20
+ from aesara .tensor .basic import MakeVector , TensorVariable , Variable
21
21
from aesara .tensor .elemwise import DimShuffle
22
22
from aesara .tensor .random .basic import RandomVariable
23
23
from aesara .tensor .var import TensorConstant
24
24
25
25
from pymc .model import Model
26
26
27
+ # from pymc.distributions.discrete import UnmeasurableConstantRV
28
+
27
29
__all__ = [
28
30
"str_for_dist" ,
29
31
"str_for_symbolic_dist" ,
@@ -61,11 +63,12 @@ def str_for_symbolic_dist(
61
63
) -> str :
62
64
def dispatch_comp_str (var , formatting = formatting , include_params = include_params ):
63
65
if var .name :
64
- return str_for_dist ( var , formatting = formatting , include_params = include_params )
66
+ return var . name
65
67
if isinstance (var , TensorConstant ):
66
- return _str_for_constant (var , formatting )
67
- if var .owner .op .name == "constant" :
68
- return "hello"
68
+ return _str_for_constant (var , formatting , print_vector = True )
69
+ if isinstance (var .owner .op , MakeVector ):
70
+ # psi in some zero inflated distribution
71
+ return dispatch_comp_str (var .owner .inputs [1 ])
69
72
70
73
# else it's a Mixture component initialized by the .dist() API
71
74
@@ -81,27 +84,36 @@ def dispatch_comp_str(var, formatting=formatting, include_params=include_params)
81
84
82
85
if include_params :
83
86
if "ZeroInflated" in rv .owner .op ._print_name [0 ]:
84
- start_idx_para = 2
87
+ # position 2 is just a constant_rv{0, (0,), shape, False}.1
88
+ assert rv .owner .inputs [2 ].owner .op .__class__ .__name__ == "UnmeasurableConstantRV"
89
+ dist_parameters = [rv .owner .inputs [1 ]] + rv .owner .inputs [3 :]
90
+
85
91
elif "Mixture" in rv .owner .op ._print_name [0 ]:
86
- start_idx_para = 1
87
92
88
93
if len (rv .owner .inputs ) == 3 :
89
94
# is a single component!
90
95
# (rng, weights, single_component)
91
- pass
96
+ rv .owner .op ._print_name = (
97
+ f"{ rv .owner .inputs [2 ].owner .op .name .capitalize ()} Mixture" ,
98
+ "\\ operatorname{" + f"{ rv .owner .inputs [2 ].owner .op .name .capitalize ()} Mixture}}" ,
99
+ )
100
+ dist_parameters = [rv .owner .inputs [1 ]] + rv .owner .inputs [2 ].owner .inputs [3 :]
101
+ else :
102
+ dist_parameters = rv .owner .inputs [1 :]
92
103
93
104
elif "Censored" in rv .owner .op ._print_name [0 ]:
94
- start_idx_para = 2
105
+ dist_parameters = rv . owner . inputs [ 2 :]
95
106
else :
96
107
# Latex representation for the SymbolicDistribution has not been implemented.
97
108
# Hoping for the best here!
98
- start_idx_para = 2
109
+ dist_parameters = rv . owner . inputs [ 2 :]
99
110
100
111
dist_args = [
101
112
dispatch_comp_str (dist_para , formatting = formatting , include_params = include_params )
102
- for dist_para in rv . owner . inputs [ start_idx_para :]
113
+ for dist_para in dist_parameters
103
114
]
104
115
116
+ # code below copied from str_for_dist
105
117
print_name = rv .name if rv .name is not None else "<unnamed>"
106
118
if "latex" in formatting :
107
119
print_name = r"\text{" + _latex_escape (print_name ) + "}"
@@ -210,11 +222,13 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
210
222
return _str
211
223
212
224
213
- def _str_for_constant (var : TensorConstant , formatting : str ) -> str :
225
+ def _str_for_constant (var : TensorConstant , formatting : str , print_vector : bool = False ) -> str :
214
226
if len (var .data .shape ) == 0 :
215
227
return f"{ var .data :.3g} "
216
228
elif len (var .data .shape ) == 1 and var .data .shape [0 ] == 1 :
217
229
return f"{ var .data [0 ]:.3g} "
230
+ elif len (var .data .shape ) == 1 and print_vector :
231
+ return "[" + ", " .join ([f"{ const :.3g} " for const in var .data ]) + "]"
218
232
elif "latex" in formatting :
219
233
return r"\text{<constant>}"
220
234
else :
0 commit comments