13
13
# limitations under the License.
14
14
15
15
16
+ from functools import partial
17
+
16
18
from pytensor .compile import SharedVariable
17
19
from pytensor .graph .basic import Constant , walk
18
20
from pytensor .tensor .basic import TensorVariable , Variable
@@ -55,7 +57,7 @@ def str_for_dist(
55
57
56
58
if "latex" in formatting :
57
59
if print_name is not None :
58
- print_name = r"\text{" + _latex_escape (dist . name .strip ("$" )) + "}"
60
+ print_name = r"\text{" + _latex_escape (print_name .strip ("$" )) + "}"
59
61
60
62
op_name = (
61
63
dist .owner .op ._print_name [1 ]
@@ -96,17 +98,16 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
96
98
"""Make a human-readable string representation of Model, listing all random variables
97
99
and their distributions, optionally including parameter values."""
98
100
99
- kwargs = dict (formatting = formatting , include_params = include_params )
100
- free_rv_reprs = [str_for_dist (dist , ** kwargs ) for dist in model .free_RVs ]
101
- observed_rv_reprs = [str_for_dist (rv , ** kwargs ) for rv in model .observed_RVs ]
102
- det_reprs = [
103
- str_for_potential_or_deterministic (dist , ** kwargs , dist_name = "Deterministic" )
104
- for dist in model .deterministics
105
- ]
106
- potential_reprs = [
107
- str_for_potential_or_deterministic (pot , ** kwargs , dist_name = "Potential" )
108
- for pot in model .potentials
109
- ]
101
+ # Wrap functions to avoid confusing typecheckers
102
+ sfd = partial (str_for_dist , formatting = formatting , include_params = include_params )
103
+ sfp = partial (
104
+ str_for_potential_or_deterministic , formatting = formatting , include_params = include_params
105
+ )
106
+
107
+ free_rv_reprs = [sfd (dist ) for dist in model .free_RVs ]
108
+ observed_rv_reprs = [sfd (rv ) for rv in model .observed_RVs ]
109
+ det_reprs = [sfp (dist , dist_name = "Deterministic" ) for dist in model .deterministics ]
110
+ potential_reprs = [sfp (pot , dist_name = "Potential" ) for pot in model .potentials ]
110
111
111
112
var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs
112
113
@@ -162,6 +163,8 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
162
163
from pymc .distributions .distribution import SymbolicRandomVariable
163
164
164
165
def _is_potential_or_deterministic (var : Variable ) -> bool :
166
+ if not hasattr (var , "str_repr" ):
167
+ return False
165
168
try :
166
169
return var .str_repr .__func__ .func is str_for_potential_or_deterministic
167
170
except AttributeError :
@@ -175,14 +178,15 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
175
178
) or _is_potential_or_deterministic (var ):
176
179
# show the names for RandomVariables, Deterministics, and Potentials, rather
177
180
# than the full expression
181
+ assert isinstance (var , TensorVariable )
178
182
return _str_for_input_rv (var , formatting )
179
183
elif isinstance (var .owner .op , DimShuffle ):
180
184
return _str_for_input_var (var .owner .inputs [0 ], formatting )
181
185
else :
182
186
return _str_for_expression (var , formatting )
183
187
184
188
185
- def _str_for_input_rv (var : Variable , formatting : str ) -> str :
189
+ def _str_for_input_rv (var : TensorVariable , formatting : str ) -> str :
186
190
_str = (
187
191
var .name
188
192
if var .name is not None
@@ -221,12 +225,15 @@ def _expand(x):
221
225
if x .owner and (not isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable )):
222
226
return reversed (x .owner .inputs )
223
227
224
- parents = [
225
- x
226
- for x in walk (nodes = var .owner .inputs , expand = _expand )
227
- if x .owner and isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable )
228
- ]
229
- names = [x .name for x in parents ]
228
+ parents = []
229
+ names = []
230
+ for x in walk (nodes = var .owner .inputs , expand = _expand ):
231
+ assert isinstance (x , Variable )
232
+ if x .owner and isinstance (x .owner .op , RandomVariable | SymbolicRandomVariable ):
233
+ parents .append (x )
234
+ xname = x .name
235
+ assert xname is not None
236
+ names .append (xname )
230
237
231
238
if "latex" in formatting :
232
239
return (
@@ -257,6 +264,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
257
264
"""Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
258
265
# we know that our str_repr does not recurse, so we can ignore cycle
259
266
try :
267
+ if not hasattr (obj , "str_repr" ):
268
+ raise AttributeError
260
269
output = obj .str_repr ()
261
270
# Find newlines and replace them with p.break_()
262
271
# (see IPython.lib.pretty._repr_pprint)
0 commit comments