@@ -81,54 +81,36 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
81
81
return "\n " .join (rv_reprs )
82
82
83
83
84
- def str_for_deterministic (
85
- var : TensorVariable , formatting : str = "plain" , include_params : bool = True
84
+ def str_for_potential_or_deterministic (
85
+ var : TensorVariable , dist_name : str , formatting : str = "plain" , include_params : bool = True
86
86
) -> str :
87
87
print_name = var .name if var .name is not None else "<unnamed>"
88
88
if "latex" in formatting :
89
89
print_name = r"\text{" + _latex_escape (print_name ) + "}"
90
90
if include_params :
91
- return fr"${ print_name } \sim \operatorname{{Deterministic}}[ { _str_for_expression (var , formatting = formatting )} ] $"
91
+ return fr"${ print_name } \sim \operatorname{{{ dist_name } }}( { _str_for_expression (var , formatting = formatting )} ) $"
92
92
else :
93
- return fr"${ print_name } \sim \operatorname{{Deterministic }}$"
93
+ return fr"${ print_name } \sim \operatorname{{{ dist_name } }}$"
94
94
else : # plain
95
95
if include_params :
96
- return (
97
- fr"{ print_name } ~ Deterministic[{ _str_for_expression (var , formatting = formatting )} ]"
98
- )
96
+ return fr"{ print_name } ~ { dist_name } ({ _str_for_expression (var , formatting = formatting )} )"
99
97
else :
100
- return fr"{ print_name } ~ Deterministic"
101
-
102
-
103
- def str_for_potential (
104
- var : TensorVariable , formatting : str = "plain" , include_params : bool = True
105
- ) -> str :
106
- print_name = var .name if var .name is not None else "<unnamed>"
107
- if "latex" in formatting :
108
- print_name = r"\text{" + _latex_escape (print_name ) + "}"
109
- if include_params :
110
- return fr"${ print_name } \sim \operatorname{{Potential}}[{ _str_for_expression (var , formatting = formatting )} ]$"
111
- else :
112
- return fr"${ print_name } \sim \operatorname{{Potential}}$"
113
- else : # plain
114
- if include_params :
115
- return fr"{ print_name } ~ Potential[{ _str_for_expression (var , formatting = formatting )} ]"
116
- else :
117
- return fr"{ print_name } ~ Potential"
98
+ return fr"{ print_name } ~ { dist_name } "
118
99
119
100
120
101
def _str_for_input_var (var : Variable , formatting : str ) -> str :
121
102
# note we're dispatching both on type(var) and on type(var.owner.op) so cannot
122
103
# use the standard functools.singledispatch
104
+
105
+ def _is_potential_or_determinstic (var : Variable ) -> bool :
106
+ return (
107
+ hasattr (var , "str_repr" )
108
+ and var .str_repr .__func__ .func is str_for_potential_or_deterministic
109
+ )
110
+
123
111
if isinstance (var , TensorConstant ):
124
112
return _str_for_constant (var , formatting )
125
- elif isinstance (var .owner .op , RandomVariable ) or (
126
- hasattr (var , "str_repr" )
127
- and (
128
- var .str_repr .__func__ is str_for_deterministic
129
- or var .str_repr .__func__ is str_for_potential
130
- )
131
- ):
113
+ elif isinstance (var .owner .op , RandomVariable ) or _is_potential_or_determinstic (var ):
132
114
# show the names for RandomVariables, Deterministics, and Potentials, rather
133
115
# than the full expression
134
116
return _str_for_input_rv (var , formatting )
0 commit comments