Skip to content

LaTeX and pretty str representations for v4, continued #4849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 13, 2021
Merged
1 change: 1 addition & 0 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __set_compiler_flags():
from pymc3.model import *
from pymc3.model_graph import model_to_graphviz
from pymc3.plots import *
from pymc3.printing import *
from pymc3.sampling import *
from pymc3.smc import *
from pymc3.stats import *
Expand Down
13 changes: 0 additions & 13 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,3 @@ class BART(BaseBART):

def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
super().__init__(X, Y, m, alpha, split_prior)

def _str_repr(self, name=None, dist=None, formatting="plain"):
if dist is None:
dist = self
X = (type(self.X),)
Y = (type(self.Y),)
alpha = self.alpha
m = self.m

if "latex" in formatting:
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
else:
return f"{name} ~ BART(alpha = {alpha}, m = {m})"
19 changes: 0 additions & 19 deletions pymc3/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,6 @@ def random(self, point=None, size=None):
# )
pass

def _distr_parameters_for_repr(self):
return ["lower", "upper"]

def _distr_name_for_repr(self):
return "Bound"

def _str_repr(self, **kwargs):
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
else:
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
self_repr = super()._str_repr(**kwargs)

if "formatting" in kwargs and "latex" in kwargs["formatting"]:
return self_repr + " -- " + distr_repr
else:
return self_repr + "-" + distr_repr


class _DiscreteBounded(_Bounded, Discrete):
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):
Expand Down
88 changes: 14 additions & 74 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextvars
import inspect
import functools
import multiprocessing
import sys
import types
Expand Down Expand Up @@ -41,7 +41,8 @@
resize_from_dims,
resize_from_observed,
)
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.printing import str_for_dist
from pymc3.util import UNSET
from pymc3.vartypes import string_types

__all__ = [
Expand Down Expand Up @@ -222,7 +223,17 @@ def __new__(
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = initval

return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
rv_out = model.register_rv(
rv_out, name, observed, total_size, dims=dims, transform=transform
)

# add in pretty-printing support
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
)

return rv_out

@classmethod
def dist(
Expand Down Expand Up @@ -313,77 +324,6 @@ def dist(

return rv_out

def _distr_parameters_for_repr(self):
"""Return the names of the parameters for this distribution (e.g. "mu"
and "sigma" for Normal). Used in generating string (and LaTeX etc.)
representations of Distribution objects. By default based on inspection
of __init__, but can be overwritten if necessary (e.g. to avoid including
"sd" and "tau").
"""
return inspect.getfullargspec(self.__init__).args[1:]

def _distr_name_for_repr(self):
return self.__class__.__name__

def _str_repr(self, name=None, dist=None, formatting="plain"):
"""
Generate string representation for this distribution, optionally
including LaTeX markup (formatting='latex').

Parameters
----------
name : str
name of the distribution
dist : Distribution
the distribution object
formatting : str
one of { "latex", "plain", "latex_with_params", "plain_with_params" }
"""
if dist is None:
dist = self
if name is None:
name = "[unnamed]"
supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
if not formatting in supported_formattings:
raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.")

param_names = self._distr_parameters_for_repr()
param_values = [
get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
]

if "latex" in formatting:
param_string = ",~".join(
[fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)]
)
if formatting == "latex_with_params":
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
)
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format(
var_name=name, distr_name=dist._distr_name_for_repr()
)
else:
# one of the plain formattings
param_string = ", ".join(
[f"{name}={value}" for name, value in zip(param_names, param_values)]
)
if formatting == "plain_with_params":
return f"{name} ~ {dist._distr_name_for_repr()}({param_string})"
return f"{name} ~ {dist._distr_name_for_repr()}"

def __str__(self, **kwargs):
try:
return self._str_repr(formatting="plain", **kwargs)
except:
return super().__str__()

def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
return self._str_repr(formatting=formatting, **kwargs)

__latex__ = _repr_latex_


class NoDistribution(Distribution):
def __init__(
Expand Down
14 changes: 0 additions & 14 deletions pymc3/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,6 @@ def random(self, point=None, size=None):
# else:
# return np.array([self.function(*params) for _ in range(size[0])])

def _str_repr(self, name=None, dist=None, formatting="plain"):
if dist is None:
dist = self
name = name
function = dist.function.__name__
params = ", ".join([var.name for var in dist.params])
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)

if "latex" in formatting:
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
else:
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"


def identity(x):
"""Identity function, used as a summary statistics."""
Expand Down
75 changes: 34 additions & 41 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import collections
import itertools
import functools
import threading
import types
import warnings

from sys import modules
Expand Down Expand Up @@ -668,6 +669,13 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()

from pymc3.printing import str_for_model
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not move this to the global level?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid the circular import (printing depends on model and vice versa). I don't like the look of this either, but it seems to be the recommended way to avoid circular imports.


self.str_repr = types.MethodType(str_for_model, self)
self._repr_latex_ = types.MethodType(
functools.partial(str_for_model, formatting="latex"), self
)

@property
def model(self):
return self
Expand Down Expand Up @@ -1628,46 +1636,6 @@ def point_logps(self, point=None, round_vals=2):
name="Log-probability of test_point",
)

def _str_repr(self, formatting="plain", **kwargs):
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)

if "latex" in formatting:
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
rv_reprs = [
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
for rv_repr in rv_reprs
if rv_repr is not None
]
return r"""$$
\begin{{array}}{{rcl}}
{}
\end{{array}}
$$""".format(
"\\\\".join(rv_reprs)
)
else:
rv_reprs = [rv.__str__() for rv in all_rv]
rv_reprs = [
rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr
]
# align vars on their ~
names = [s[: s.index("~") - 1] for s in rv_reprs]
distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
maxlen = str(max(len(x) for x in names))
rv_reprs = [
("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
for n, d in zip(names, distrs)
]
return "\n".join(rv_reprs)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)

def _repr_latex_(self, *, formatting="latex", **kwargs):
return self._str_repr(formatting=formatting, **kwargs)

__latex__ = _repr_latex_


# this is really disgusting, but it breaks a self-loop: I can't pass Model
# itself as context class init arg.
Expand Down Expand Up @@ -1821,6 +1789,18 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
model.deterministics.append(var)
model.add_random_variable(var, dims)

from pymc3.printing import str_for_potential_or_deterministic

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
),
var,
)

return var


Expand All @@ -1841,4 +1821,17 @@ def Potential(name, var, model=None):
var.tag.scaling = None
model.potentials.append(var)
model.add_random_variable(var)

from pymc3.printing import str_for_potential_or_deterministic

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
),
var,
)

return var
Loading