Skip to content

Commit b4ce5ae

Browse files
committed
Deprecate str_for_symbolic_dist in favor of str_for_dist
* Display the dists and their inputs used by SymbolicDistributions
1 parent 2e3247d commit b4ce5ae

File tree

6 files changed

+64
-31
lines changed

6 files changed

+64
-31
lines changed

pymc/distributions/censored.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class CensoredRV(SymbolicRandomVariable):
3030
"""Censored random variable"""
3131

3232
inline_aeppl = True
33+
_print_name = ("Censored", "\\operatorname{Censored}")
3334

3435

3536
class Censored(SymbolicDistribution):

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
find_size,
5353
shape_from_dims,
5454
)
55-
from pymc.printing import str_for_dist, str_for_symbolic_dist
55+
from pymc.printing import str_for_dist
5656
from pymc.util import UNSET
5757
from pymc.vartypes import string_types
5858

@@ -531,9 +531,9 @@ def __new__(
531531
initval=initval,
532532
)
533533
# add in pretty-printing support
534-
rv_out.str_repr = types.MethodType(str_for_symbolic_dist, rv_out)
534+
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
535535
rv_out._repr_latex_ = types.MethodType(
536-
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
536+
functools.partial(str_for_dist, formatting="latex"), rv_out
537537
)
538538

539539
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")

pymc/distributions/mixture.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class MarginalMixtureRV(SymbolicRandomVariable):
4545
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
4646

4747
default_output = 1
48+
_print_name = ("MarginalMixture", "\\operatorname{MarginalMixture}")
4849

4950
def update(self, node: Node):
5051
# Update for the internal mix_indexes RV

pymc/distributions/timeseries.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class AutoRegressiveRV(SymbolicRandomVariable):
335335
default_output = 1
336336
ar_order: int
337337
constant_term: bool
338+
_print_name = ("AR", "\\operatorname{AR}")
338339

339340
def __init__(self, *args, ar_order, constant_term, **kwargs):
340341
self.ar_order = ar_order

pymc/printing.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
from aesara.tensor.basic import TensorVariable, Variable
2121
from aesara.tensor.elemwise import DimShuffle
2222
from aesara.tensor.random.basic import RandomVariable
23+
from aesara.tensor.random.var import (
24+
RandomGeneratorSharedVariable,
25+
RandomStateSharedVariable,
26+
)
2327
from aesara.tensor.var import TensorConstant
2428

2529
from pymc.model import Model
@@ -31,40 +35,62 @@
3135
]
3236

3337

34-
def str_for_dist(rv: TensorVariable, formatting: str = "plain", include_params: bool = True) -> str:
35-
"""Make a human-readable string representation of a RandomVariable in a model, either
38+
def str_for_dist(
39+
dist: TensorVariable, formatting: str = "plain", include_params: bool = True
40+
) -> str:
41+
"""Make a human-readable string representation of a Distribution in a model, either
3642
LaTeX or plain, optionally with distribution parameter values included."""
3743

3844
if include_params:
3945
# first 3 args are always (rng, size, dtype), rest is relevant for distribution
40-
dist_args = [_str_for_input_var(x, formatting=formatting) for x in rv.owner.inputs[3:]]
46+
if isinstance(dist.owner.op, RandomVariable):
47+
dist_args = [
48+
_str_for_input_var(x, formatting=formatting) for x in dist.owner.inputs[3:]
49+
]
50+
else:
51+
dist_args = [
52+
_str_for_input_var(x, formatting=formatting)
53+
for x in dist.owner.inputs
54+
if not isinstance(x, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
55+
]
56+
57+
print_name = dist.name
4158

42-
print_name = rv.name if rv.name is not None else "<unnamed>"
4359
if "latex" in formatting:
44-
print_name = r"\text{" + _latex_escape(print_name) + "}"
45-
dist_name = rv.owner.op._print_name[1]
60+
if print_name is not None:
61+
print_name = r"\text{" + _latex_escape(dist.name) + "}"
62+
63+
op_name = (
64+
dist.owner.op._print_name[1]
65+
if hasattr(dist.owner.op, "_print_name")
66+
else r"\\operatorname{Unknown}"
67+
)
4668
if include_params:
47-
return r"${} \sim {}({})$".format(print_name, dist_name, ",~".join(dist_args))
69+
if print_name:
70+
return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args))
71+
else:
72+
return r"${}({})$".format(op_name, ",~".join(dist_args))
73+
4874
else:
49-
return rf"${print_name} \sim {dist_name}$"
75+
if print_name:
76+
return rf"${print_name} \sim {op_name}$"
77+
else:
78+
return rf"${op_name}$"
79+
5080
else: # plain
51-
dist_name = rv.owner.op._print_name[0]
81+
dist_name = (
82+
dist.owner.op._print_name[0] if hasattr(dist.owner.op, "_print_name") else "Unknown"
83+
)
5284
if include_params:
53-
return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args))
85+
if print_name:
86+
return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args))
87+
else:
88+
return r"{}({})".format(dist_name, ", ".join(dist_args))
5489
else:
55-
return rf"{print_name} ~ {dist_name}"
56-
57-
58-
def str_for_symbolic_dist(
59-
rv: TensorVariable, formatting: str = "plain", include_params: bool = True
60-
) -> str:
61-
"""Make a human-readable string representation of a SymbolicDistribution in a model,
62-
either LaTeX or plain, optionally with distribution parameter values included."""
63-
64-
if "latex" in formatting:
65-
return rf"$\text{{{rv.name}}} \sim \text{{{rv.owner.op}}}$"
66-
else:
67-
return rf"{rv.name} ~ {rv.owner.op}"
90+
if print_name:
91+
return rf"{print_name} ~ {dist_name}"
92+
else:
93+
return dist_name
6894

6995

7096
def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
@@ -145,7 +171,11 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
145171

146172

147173
def _str_for_input_rv(var: Variable, formatting: str) -> str:
148-
_str = var.name if var.name is not None else "<unnamed>"
174+
_str = (
175+
var.name
176+
if var.name is not None
177+
else str_for_dist(var, formatting=formatting, include_params=True)
178+
)
149179
if "latex" in formatting:
150180
return r"\text{" + _latex_escape(_str) + "}"
151181
else:

pymc/tests/test_printing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def setup_class(self):
9292
r"beta ~ N(0, 10)",
9393
r"Z ~ N(f(), f())",
9494
r"nb_with_p_n ~ NB(10, nbp)",
95-
r"zip ~ MarginalMixtureRV{inline=False}",
95+
r"zip ~ MarginalMixture(f(), DiracDelta(0), Pois(5))",
9696
r"Y_obs ~ N(mu, sigma)",
9797
r"pot ~ Potential(f(beta, alpha))",
9898
],
@@ -103,7 +103,7 @@ def setup_class(self):
103103
r"beta ~ N",
104104
r"Z ~ N",
105105
r"nb_with_p_n ~ NB",
106-
r"zip ~ MarginalMixtureRV{inline=False}",
106+
r"zip ~ MarginalMixture",
107107
r"Y_obs ~ N",
108108
r"pot ~ Potential",
109109
],
@@ -114,7 +114,7 @@ def setup_class(self):
114114
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
115115
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
116116
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
117-
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
117+
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
118118
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
119119
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
120120
],
@@ -125,7 +125,7 @@ def setup_class(self):
125125
r"$\text{beta} \sim \operatorname{N}$",
126126
r"$\text{Z} \sim \operatorname{N}$",
127127
r"$\text{nb_with_p_n} \sim \operatorname{NB}$",
128-
r"$\text{zip} \sim \text{MarginalMixtureRV{inline=False}}$",
128+
r"$\text{zip} \sim \operatorname{MarginalMixture}$",
129129
r"$\text{Y_obs} \sim \operatorname{N}$",
130130
r"$\text{pot} \sim \operatorname{Potential}$",
131131
],

0 commit comments

Comments
 (0)