Skip to content

Commit cd89ad6

Browse files
committed
add pretty str/latex for Deterministic and Potential
1 parent a58ca64 commit cd89ad6

File tree

3 files changed

+69
-12
lines changed

3 files changed

+69
-12
lines changed

pymc3/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
resize_from_dims,
4242
resize_from_observed,
4343
)
44-
from pymc3.printing import str_repr
44+
from pymc3.printing import str_for_dist
4545
from pymc3.util import UNSET
4646
from pymc3.vartypes import string_types
4747

@@ -228,9 +228,9 @@ def __new__(
228228
)
229229

230230
# add in pretty-printing support
231-
rv_out.str_repr = types.MethodType(str_repr, rv_out)
231+
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
232232
rv_out._repr_latex_ = types.MethodType(
233-
functools.partial(str_repr, formatting="latex"), rv_out
233+
functools.partial(str_for_dist, formatting="latex"), rv_out
234234
)
235235

236236
return rv_out

pymc3/model.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -669,10 +669,12 @@ def __init__(
669669
self.deterministics = treelist()
670670
self.potentials = treelist()
671671

672-
from pymc3.printing import str_repr
672+
from pymc3.printing import str_for_model
673673

674-
self.str_repr = types.MethodType(str_repr, self)
675-
self._repr_latex_ = types.MethodType(functools.partial(str_repr, formatting="latex"), self)
674+
self.str_repr = types.MethodType(str_for_model, self)
675+
self._repr_latex_ = types.MethodType(
676+
functools.partial(str_for_model, formatting="latex"), self
677+
)
676678

677679
@property
678680
def model(self):
@@ -1787,6 +1789,13 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
17871789
model.deterministics.append(var)
17881790
model.add_random_variable(var, dims)
17891791

1792+
from pymc3.printing import str_for_deterministic
1793+
1794+
var.str_repr = types.MethodType(str_for_deterministic, var)
1795+
var._repr_latex_ = types.MethodType(
1796+
functools.partial(str_for_deterministic, formatting="latex"), var
1797+
)
1798+
17901799
return var
17911800

17921801

@@ -1807,4 +1816,12 @@ def Potential(name, var, model=None):
18071816
var.tag.scaling = None
18081817
model.potentials.append(var)
18091818
model.add_random_variable(var)
1819+
1820+
from pymc3.printing import str_for_potential
1821+
1822+
var.str_repr = types.MethodType(str_for_potential, var)
1823+
var._repr_latex_ = types.MethodType(
1824+
functools.partial(str_for_potential, formatting="latex"), var
1825+
)
1826+
18101827
return var

pymc3/printing.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import itertools
1616

17-
from functools import singledispatch
1817
from typing import Union
1918

2019
from aesara.graph.basic import walk
@@ -26,8 +25,7 @@
2625
from pymc3.model import Model
2726

2827

29-
@singledispatch
30-
def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str:
28+
def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str:
3129
"""Make a human-readable string representation of a RandomVariable in a model, either
3230
LaTeX or plain, optionally with distribution parameter values included."""
3331

@@ -51,11 +49,10 @@ def str_repr(rv: TensorVariable, formatting: str = "plain", include_params: bool
5149
return fr"{print_name} ~ {dist_name}"
5250

5351

54-
@str_repr.register
55-
def _(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
52+
def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
5653
"""Make a human-readable string representation of Model, listing all random variables
5754
and their distributions, optionally including parameter values."""
58-
all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs)
55+
all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials)
5956

6057
rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv]
6158
rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr]
@@ -84,6 +81,44 @@ def _(model: Model, formatting: str = "plain", include_params: bool = True) -> s
8481
return "\n".join(rv_reprs)
8582

8683

84+
def str_for_deterministic(
85+
var: TensorVariable, formatting: str = "plain", include_params: bool = True
86+
) -> str:
87+
print_name = var.name if var.name is not None else "<unnamed>"
88+
if "latex" in formatting:
89+
print_name = r"\text{" + _latex_escape(print_name) + "}"
90+
if include_params:
91+
return fr"${print_name} \sim Deterministic[{_str_for_expression(var, formatting=formatting)}]$"
92+
else:
93+
return fr"${print_name} \sim Deterministic$"
94+
else: # plain
95+
if include_params:
96+
return (
97+
fr"{print_name} ~ Deterministic[{_str_for_expression(var, formatting=formatting)}]"
98+
)
99+
else:
100+
return fr"{print_name} ~ Deterministic"
101+
102+
103+
def str_for_potential(
104+
var: TensorVariable, formatting: str = "plain", include_params: bool = True
105+
) -> str:
106+
print_name = var.name if var.name is not None else "<unnamed>"
107+
if "latex" in formatting:
108+
print_name = r"\text{" + _latex_escape(print_name) + "}"
109+
if include_params:
110+
return (
111+
fr"${print_name} \sim Potential[{_str_for_expression(var, formatting=formatting)}]$"
112+
)
113+
else:
114+
return fr"${print_name} \sim Potential$"
115+
else: # plain
116+
if include_params:
117+
return fr"{print_name} ~ Potential[{_str_for_expression(var, formatting=formatting)}]"
118+
else:
119+
return fr"{print_name} ~ Potential"
120+
121+
87122
def _str_for_input_var(var: Variable, formatting: str) -> str:
88123
# note we're dispatching both on type(var) and on type(var.owner.op) so cannot
89124
# use the standard functools.singledispatch
@@ -93,6 +128,11 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
93128
return _str_for_input_rv(var, formatting)
94129
elif isinstance(var.owner.op, DimShuffle):
95130
return _str_for_input_var(var.owner.inputs[0], formatting)
131+
elif hasattr(var, "str_repr") and (
132+
var.str_repr.__func__ is str_for_deterministic or var.str_repr.__func__ is str_for_potential
133+
):
134+
# display the name for a Deterministic or Potential, rather than the full expression
135+
return var.name
96136
else:
97137
return _str_for_expression(var, formatting)
98138

0 commit comments

Comments
 (0)