Skip to content

Commit 75045b3

Browse files
committed
refactor: unify str_for_potential_or_deterministic
1 parent 7f6bc9a commit 75045b3

File tree

2 files changed

+30
-38
lines changed

2 files changed

+30
-38
lines changed

pymc3/model.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,11 +1789,16 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
17891789
model.deterministics.append(var)
17901790
model.add_random_variable(var, dims)
17911791

1792-
from pymc3.printing import str_for_deterministic
1792+
from pymc3.printing import str_for_potential_or_deterministic
17931793

1794-
var.str_repr = types.MethodType(str_for_deterministic, var)
1794+
var.str_repr = types.MethodType(
1795+
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
1796+
)
17951797
var._repr_latex_ = types.MethodType(
1796-
functools.partial(str_for_deterministic, formatting="latex"), var
1798+
functools.partial(
1799+
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
1800+
),
1801+
var,
17971802
)
17981803

17991804
return var
@@ -1817,11 +1822,16 @@ def Potential(name, var, model=None):
18171822
model.potentials.append(var)
18181823
model.add_random_variable(var)
18191824

1820-
from pymc3.printing import str_for_potential
1825+
from pymc3.printing import str_for_potential_or_deterministic
18211826

1822-
var.str_repr = types.MethodType(str_for_potential, var)
1827+
var.str_repr = types.MethodType(
1828+
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
1829+
)
18231830
var._repr_latex_ = types.MethodType(
1824-
functools.partial(str_for_potential, formatting="latex"), var
1831+
functools.partial(
1832+
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
1833+
),
1834+
var,
18251835
)
18261836

18271837
return var

pymc3/printing.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -81,54 +81,36 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
8181
return "\n".join(rv_reprs)
8282

8383

84-
def str_for_deterministic(
85-
var: TensorVariable, formatting: str = "plain", include_params: bool = True
84+
def str_for_potential_or_deterministic(
85+
var: TensorVariable, dist_name: str, formatting: str = "plain", include_params: bool = True
8686
) -> str:
8787
print_name = var.name if var.name is not None else "<unnamed>"
8888
if "latex" in formatting:
8989
print_name = r"\text{" + _latex_escape(print_name) + "}"
9090
if include_params:
91-
return fr"${print_name} \sim \operatorname{{Deterministic}}[{_str_for_expression(var, formatting=formatting)}]$"
91+
return fr"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$"
9292
else:
93-
return fr"${print_name} \sim \operatorname{{Deterministic}}$"
93+
return fr"${print_name} \sim \operatorname{{{dist_name}}}$"
9494
else: # plain
9595
if include_params:
96-
return (
97-
fr"{print_name} ~ Deterministic[{_str_for_expression(var, formatting=formatting)}]"
98-
)
96+
return fr"{print_name} ~ {dist_name}({_str_for_expression(var, formatting=formatting)})"
9997
else:
100-
return fr"{print_name} ~ Deterministic"
101-
102-
103-
def str_for_potential(
104-
var: TensorVariable, formatting: str = "plain", include_params: bool = True
105-
) -> str:
106-
print_name = var.name if var.name is not None else "<unnamed>"
107-
if "latex" in formatting:
108-
print_name = r"\text{" + _latex_escape(print_name) + "}"
109-
if include_params:
110-
return fr"${print_name} \sim \operatorname{{Potential}}[{_str_for_expression(var, formatting=formatting)}]$"
111-
else:
112-
return fr"${print_name} \sim \operatorname{{Potential}}$"
113-
else: # plain
114-
if include_params:
115-
return fr"{print_name} ~ Potential[{_str_for_expression(var, formatting=formatting)}]"
116-
else:
117-
return fr"{print_name} ~ Potential"
98+
return fr"{print_name} ~ {dist_name}"
11899

119100

120101
def _str_for_input_var(var: Variable, formatting: str) -> str:
121102
# note we're dispatching both on type(var) and on type(var.owner.op) so cannot
122103
# use the standard functools.singledispatch
104+
105+
def _is_potential_or_determinstic(var: Variable) -> bool:
106+
return (
107+
hasattr(var, "str_repr")
108+
and var.str_repr.__func__.func is str_for_potential_or_deterministic
109+
)
110+
123111
if isinstance(var, TensorConstant):
124112
return _str_for_constant(var, formatting)
125-
elif isinstance(var.owner.op, RandomVariable) or (
126-
hasattr(var, "str_repr")
127-
and (
128-
var.str_repr.__func__ is str_for_deterministic
129-
or var.str_repr.__func__ is str_for_potential
130-
)
131-
):
113+
elif isinstance(var.owner.op, RandomVariable) or _is_potential_or_determinstic(var):
132114
# show the names for RandomVariables, Deterministics, and Potentials, rather
133115
# than the full expression
134116
return _str_for_input_rv(var, formatting)

0 commit comments

Comments
 (0)