14
14
15
15
import itertools
16
16
17
- from functools import singledispatch
18
17
from typing import Union
19
18
20
19
from aesara .graph .basic import walk
26
25
from pymc3 .model import Model
27
26
28
27
29
- @singledispatch
30
- def str_repr (rv : TensorVariable , formatting : str = "plain" , include_params : bool = True ) -> str :
28
+ def str_for_dist (rv : TensorVariable , formatting : str = "plain" , include_params : bool = True ) -> str :
31
29
"""Make a human-readable string representation of a RandomVariable in a model, either
32
30
LaTeX or plain, optionally with distribution parameter values included."""
33
31
@@ -51,11 +49,10 @@ def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool
51
49
return fr"{ print_name } ~ { dist_name } "
52
50
53
51
54
- @str_repr .register
55
- def _ (model : Model , formatting : str = "plain" , include_params : bool = True ) -> str :
52
+ def str_for_model (model : Model , formatting : str = "plain" , include_params : bool = True ) -> str :
56
53
"""Make a human-readable string representation of Model, listing all random variables
57
54
and their distributions, optionally including parameter values."""
58
- all_rv = itertools .chain (model .unobserved_RVs , model .observed_RVs )
55
+ all_rv = itertools .chain (model .unobserved_RVs , model .observed_RVs , model . potentials )
59
56
60
57
rv_reprs = [rv .str_repr (formatting = formatting , include_params = include_params ) for rv in all_rv ]
61
58
rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr ]
@@ -84,6 +81,44 @@ def _(model: Model, formatting: str = "plain", include_params: bool = True) -> s
84
81
return "\n " .join (rv_reprs )
85
82
86
83
84
+ def str_for_deterministic (
85
+ var : TensorVariable , formatting : str = "plain" , include_params : bool = True
86
+ ) -> str :
87
+ print_name = var .name if var .name is not None else "<unnamed>"
88
+ if "latex" in formatting :
89
+ print_name = r"\text{" + _latex_escape (print_name ) + "}"
90
+ if include_params :
91
+ return fr"${ print_name } \sim Deterministic[{ _str_for_expression (var , formatting = formatting )} ]$"
92
+ else :
93
+ return fr"${ print_name } \sim Deterministic$"
94
+ else : # plain
95
+ if include_params :
96
+ return (
97
+ fr"{ print_name } ~ Deterministic[{ _str_for_expression (var , formatting = formatting )} ]"
98
+ )
99
+ else :
100
+ return fr"{ print_name } ~ Deterministic"
101
+
102
+
103
+ def str_for_potential (
104
+ var : TensorVariable , formatting : str = "plain" , include_params : bool = True
105
+ ) -> str :
106
+ print_name = var .name if var .name is not None else "<unnamed>"
107
+ if "latex" in formatting :
108
+ print_name = r"\text{" + _latex_escape (print_name ) + "}"
109
+ if include_params :
110
+ return (
111
+ fr"${ print_name } \sim Potential[{ _str_for_expression (var , formatting = formatting )} ]$"
112
+ )
113
+ else :
114
+ return fr"${ print_name } \sim Potential$"
115
+ else : # plain
116
+ if include_params :
117
+ return fr"{ print_name } ~ Potential[{ _str_for_expression (var , formatting = formatting )} ]"
118
+ else :
119
+ return fr"{ print_name } ~ Potential"
120
+
121
+
87
122
def _str_for_input_var (var : Variable , formatting : str ) -> str :
88
123
# note we're dispatching both on type(var) and on type(var.owner.op) so cannot
89
124
# use the standard functools.singledispatch
@@ -93,6 +128,11 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
93
128
return _str_for_input_rv (var , formatting )
94
129
elif isinstance (var .owner .op , DimShuffle ):
95
130
return _str_for_input_var (var .owner .inputs [0 ], formatting )
131
+ elif hasattr (var , "str_repr" ) and (
132
+ var .str_repr .__func__ is str_for_deterministic or var .str_repr .__func__ is str_for_potential
133
+ ):
134
+ # display the name for a Deterministic or Potential, rather than the full expression
135
+ return var .name
96
136
else :
97
137
return _str_for_expression (var , formatting )
98
138
0 commit comments