Skip to content

Commit 8623cd4

Browse files
authored
LaTeX and pretty str representations for v4, continued (#4849)
* new simple pretty console printing for Model and RandomVariable * removing obsolete repr/latex code * add latex support to Model, use PrettyPrinter.break_ * more appropriate type hint (TensorVariable) * remove obsolete escape_latex * add pretty str/latex for Deterministic and Potential * update escape characters in latex formats * small refactor * refactor: unify str_for_potential_or_deterministic * refactor: safer fallback if user code changes str_repr * update tests, add __all__ * import printing in root module * use cloudpickle in smc sampling
1 parent 83c5a30 commit 8623cd4

File tree

10 files changed

+371
-321
lines changed

10 files changed

+371
-321
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __set_compiler_flags():
6868
from pymc3.model import *
6969
from pymc3.model_graph import model_to_graphviz
7070
from pymc3.plots import *
71+
from pymc3.printing import *
7172
from pymc3.sampling import *
7273
from pymc3.smc import *
7374
from pymc3.stats import *

pymc3/distributions/bart.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -282,16 +282,3 @@ class BART(BaseBART):
282282

283283
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
284284
super().__init__(X, Y, m, alpha, split_prior)
285-
286-
def _str_repr(self, name=None, dist=None, formatting="plain"):
287-
if dist is None:
288-
dist = self
289-
X = (type(self.X),)
290-
Y = (type(self.Y),)
291-
alpha = self.alpha
292-
m = self.m
293-
294-
if "latex" in formatting:
295-
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
296-
else:
297-
return f"{name} ~ BART(alpha = {alpha}, m = {m})"

pymc3/distributions/bound.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,25 +143,6 @@ def random(self, point=None, size=None):
143143
# )
144144
pass
145145

146-
def _distr_parameters_for_repr(self):
147-
return ["lower", "upper"]
148-
149-
def _distr_name_for_repr(self):
150-
return "Bound"
151-
152-
def _str_repr(self, **kwargs):
153-
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
154-
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
155-
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
156-
else:
157-
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
158-
self_repr = super()._str_repr(**kwargs)
159-
160-
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
161-
return self_repr + " -- " + distr_repr
162-
else:
163-
return self_repr + "-" + distr_repr
164-
165146

166147
class _DiscreteBounded(_Bounded, Discrete):
167148
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):

pymc3/distributions/distribution.py

Lines changed: 14 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import contextvars
15-
import inspect
15+
import functools
1616
import multiprocessing
1717
import sys
1818
import types
@@ -41,7 +41,8 @@
4141
resize_from_dims,
4242
resize_from_observed,
4343
)
44-
from pymc3.util import UNSET, get_repr_for_variable
44+
from pymc3.printing import str_for_dist
45+
from pymc3.util import UNSET
4546
from pymc3.vartypes import string_types
4647

4748
__all__ = [
@@ -222,7 +223,17 @@ def __new__(
222223
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
223224
rv_out.tag.test_value = initval
224225

225-
return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
226+
rv_out = model.register_rv(
227+
rv_out, name, observed, total_size, dims=dims, transform=transform
228+
)
229+
230+
# add in pretty-printing support
231+
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
232+
rv_out._repr_latex_ = types.MethodType(
233+
functools.partial(str_for_dist, formatting="latex"), rv_out
234+
)
235+
236+
return rv_out
226237

227238
@classmethod
228239
def dist(
@@ -313,77 +324,6 @@ def dist(
313324

314325
return rv_out
315326

316-
def _distr_parameters_for_repr(self):
317-
"""Return the names of the parameters for this distribution (e.g. "mu"
318-
and "sigma" for Normal). Used in generating string (and LaTeX etc.)
319-
representations of Distribution objects. By default based on inspection
320-
of __init__, but can be overwritten if necessary (e.g. to avoid including
321-
"sd" and "tau").
322-
"""
323-
return inspect.getfullargspec(self.__init__).args[1:]
324-
325-
def _distr_name_for_repr(self):
326-
return self.__class__.__name__
327-
328-
def _str_repr(self, name=None, dist=None, formatting="plain"):
329-
"""
330-
Generate string representation for this distribution, optionally
331-
including LaTeX markup (formatting='latex').
332-
333-
Parameters
334-
----------
335-
name : str
336-
name of the distribution
337-
dist : Distribution
338-
the distribution object
339-
formatting : str
340-
one of { "latex", "plain", "latex_with_params", "plain_with_params" }
341-
"""
342-
if dist is None:
343-
dist = self
344-
if name is None:
345-
name = "[unnamed]"
346-
supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
347-
if not formatting in supported_formattings:
348-
raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.")
349-
350-
param_names = self._distr_parameters_for_repr()
351-
param_values = [
352-
get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
353-
]
354-
355-
if "latex" in formatting:
356-
param_string = ",~".join(
357-
[fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)]
358-
)
359-
if formatting == "latex_with_params":
360-
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
361-
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
362-
)
363-
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format(
364-
var_name=name, distr_name=dist._distr_name_for_repr()
365-
)
366-
else:
367-
# one of the plain formattings
368-
param_string = ", ".join(
369-
[f"{name}={value}" for name, value in zip(param_names, param_values)]
370-
)
371-
if formatting == "plain_with_params":
372-
return f"{name} ~ {dist._distr_name_for_repr()}({param_string})"
373-
return f"{name} ~ {dist._distr_name_for_repr()}"
374-
375-
def __str__(self, **kwargs):
376-
try:
377-
return self._str_repr(formatting="plain", **kwargs)
378-
except:
379-
return super().__str__()
380-
381-
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
382-
"""Magic method name for IPython to use for LaTeX formatting."""
383-
return self._str_repr(formatting=formatting, **kwargs)
384-
385-
__latex__ = _repr_latex_
386-
387327

388328
class NoDistribution(Distribution):
389329
def __init__(

pymc3/distributions/simulator.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,6 @@ def random(self, point=None, size=None):
121121
# else:
122122
# return np.array([self.function(*params) for _ in range(size[0])])
123123

124-
def _str_repr(self, name=None, dist=None, formatting="plain"):
125-
if dist is None:
126-
dist = self
127-
name = name
128-
function = dist.function.__name__
129-
params = ", ".join([var.name for var in dist.params])
130-
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
131-
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)
132-
133-
if "latex" in formatting:
134-
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
135-
else:
136-
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"
137-
138124

139125
def identity(x):
140126
"""Identity function, used as a summary statistics."""

pymc3/model.py

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414

1515
import collections
16-
import itertools
16+
import functools
1717
import threading
18+
import types
1819
import warnings
1920

2021
from sys import modules
@@ -668,6 +669,13 @@ def __init__(
668669
self.deterministics = treelist()
669670
self.potentials = treelist()
670671

672+
from pymc3.printing import str_for_model
673+
674+
self.str_repr = types.MethodType(str_for_model, self)
675+
self._repr_latex_ = types.MethodType(
676+
functools.partial(str_for_model, formatting="latex"), self
677+
)
678+
671679
@property
672680
def model(self):
673681
return self
@@ -1628,46 +1636,6 @@ def point_logps(self, point=None, round_vals=2):
16281636
name="Log-probability of test_point",
16291637
)
16301638

1631-
def _str_repr(self, formatting="plain", **kwargs):
1632-
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)
1633-
1634-
if "latex" in formatting:
1635-
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
1636-
rv_reprs = [
1637-
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
1638-
for rv_repr in rv_reprs
1639-
if rv_repr is not None
1640-
]
1641-
return r"""$$
1642-
\begin{{array}}{{rcl}}
1643-
{}
1644-
\end{{array}}
1645-
$$""".format(
1646-
"\\\\".join(rv_reprs)
1647-
)
1648-
else:
1649-
rv_reprs = [rv.__str__() for rv in all_rv]
1650-
rv_reprs = [
1651-
rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr
1652-
]
1653-
# align vars on their ~
1654-
names = [s[: s.index("~") - 1] for s in rv_reprs]
1655-
distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
1656-
maxlen = str(max(len(x) for x in names))
1657-
rv_reprs = [
1658-
("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
1659-
for n, d in zip(names, distrs)
1660-
]
1661-
return "\n".join(rv_reprs)
1662-
1663-
def __str__(self, **kwargs):
1664-
return self._str_repr(formatting="plain", **kwargs)
1665-
1666-
def _repr_latex_(self, *, formatting="latex", **kwargs):
1667-
return self._str_repr(formatting=formatting, **kwargs)
1668-
1669-
__latex__ = _repr_latex_
1670-
16711639

16721640
# this is really disgusting, but it breaks a self-loop: I can't pass Model
16731641
# itself as context class init arg.
@@ -1821,6 +1789,18 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
18211789
model.deterministics.append(var)
18221790
model.add_random_variable(var, dims)
18231791

1792+
from pymc3.printing import str_for_potential_or_deterministic
1793+
1794+
var.str_repr = types.MethodType(
1795+
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
1796+
)
1797+
var._repr_latex_ = types.MethodType(
1798+
functools.partial(
1799+
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
1800+
),
1801+
var,
1802+
)
1803+
18241804
return var
18251805

18261806

@@ -1841,4 +1821,17 @@ def Potential(name, var, model=None):
18411821
var.tag.scaling = None
18421822
model.potentials.append(var)
18431823
model.add_random_variable(var)
1824+
1825+
from pymc3.printing import str_for_potential_or_deterministic
1826+
1827+
var.str_repr = types.MethodType(
1828+
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
1829+
)
1830+
var._repr_latex_ = types.MethodType(
1831+
functools.partial(
1832+
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
1833+
),
1834+
var,
1835+
)
1836+
18441837
return var

0 commit comments

Comments
 (0)