Skip to content

Commit 938604c

Browse files
ricardoV94twiecki
authored andcommitted
Add temporary fix for pretty representation of SymbolicDistributions
1 parent 57c6a8f commit 938604c

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

pymc/distributions/distribution.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
resize_from_dims,
4848
resize_from_observed,
4949
)
50-
from pymc.printing import str_for_dist
50+
from pymc.printing import str_for_dist, str_for_symbolic_dist
5151
from pymc.util import UNSET
5252
from pymc.vartypes import string_types
5353

@@ -483,15 +483,11 @@ def __new__(
483483
transform=transform,
484484
initval=initval,
485485
)
486-
487-
# TODO: Refactor this
488486
# add in pretty-printing support
489-
rv_out.str_repr = lambda *args, **kwargs: name
490-
rv_out._repr_latex_ = f"\\text{name}"
491-
# rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
492-
# rv_out._repr_latex_ = types.MethodType(
493-
# functools.partial(str_for_dist, formatting="latex"), rv_out
494-
# )
487+
rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out)
488+
rv_out._repr_latex_ = types.MethodType(
489+
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
490+
)
495491

496492
return rv_out
497493

pymc/printing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params:
5555
return rf"{print_name} ~ {dist_name}"
5656

5757

58+
def str_for_symbolic_dist(
59+
rv: TensorVariable, formatting: str = "plain", include_params: bool = True
60+
) -> str:
61+
"""Make a human-readable string representation of a SymbolicDistribution in a model,
62+
either LaTeX or plain, optionally with distribution parameter values included."""
63+
64+
if "latex" in formatting:
65+
return rf"$\text{{{rv.name}}} \sim \text{{{rv.owner.op}}}$"
66+
else:
67+
return rf"{rv.name} ~ {rv.owner.op}"
68+
69+
5870
def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
5971
"""Make a human-readable string representation of Model, listing all random variables
6072
and their distributions, optionally including parameter values."""

pymc/tests/test_printing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
NegativeBinomial,
1010
Normal,
1111
Uniform,
12+
ZeroInflatedPoisson,
1213
)
1314
from pymc.math import dot
1415
from pymc.model import Deterministic, Model, Potential
@@ -47,6 +48,9 @@ def setup_class(self):
4748
# )
4849
nb2 = NegativeBinomial("nb_with_p_n", p=Uniform("nbp"), n=10)
4950

51+
# Symbolic distribution
52+
zip = ZeroInflatedPoisson("zip", 0.5, 5)
53+
5054
# Expected value of outcome
5155
mu = Deterministic("mu", floatX(alpha + dot(X, b)))
5256

@@ -76,7 +80,7 @@ def setup_class(self):
7680
# add a potential as well
7781
pot = Potential("pot", mu**2)
7882

79-
self.distributions = [alpha, sigma, mu, b, Z, nb2, Y_obs, pot]
83+
self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, Y_obs, pot]
8084
self.deterministics_or_potentials = [mu, pot]
8185
# tuples of (formatting, include_params
8286
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
@@ -88,6 +92,7 @@ def setup_class(self):
8892
r"beta ~ N(0, 10)",
8993
r"Z ~ N(f(), f())",
9094
r"nb_with_p_n ~ NB(10, nbp)",
95+
r"zip ~ MarginalMixtureRV{inline=False}",
9196
r"Y_obs ~ N(mu, sigma)",
9297
r"pot ~ Potential(f(beta, alpha))",
9398
],
@@ -98,6 +103,7 @@ def setup_class(self):
98103
r"beta ~ N",
99104
r"Z ~ N",
100105
r"nb_with_p_n ~ NB",
106+
r"zip ~ MarginalMixtureRV{inline=False}",
101107
r"Y_obs ~ N",
102108
r"pot ~ Potential",
103109
],
@@ -108,6 +114,7 @@ def setup_class(self):
108114
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
109115
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
110116
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
117+
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
111118
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
112119
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
113120
],
@@ -118,6 +125,7 @@ def setup_class(self):
118125
r"$\text{beta} \sim \operatorname{N}$",
119126
r"$\text{Z} \sim \operatorname{N}$",
120127
r"$\text{nb_with_p_n} \sim \operatorname{NB}$",
128+
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
121129
r"$\text{Y_obs} \sim \operatorname{N}$",
122130
r"$\text{pot} \sim \operatorname{Potential}$",
123131
],

0 commit comments

Comments
 (0)