Skip to content

Commit 27b1a7d

Browse files
Add str/repr formatting options and change defaults accordingly (#4260)
* Add str/repr formatting options and change defaults accordingly
  + New options "latex_with_params" and "plain_with_params" replace the current default behavior of including input parameters (going back to the default behavior of the 3.9.3 release).
  + `__latex__` and `_repr_latex_` default to "latex_with_params".
  + `__str__` and `_str_repr` default to "plain".
  + New `formatting` kwarg for `model_to_graphviz` enables switching between "plain" (default) and "plain_with_params".
* Update if conditions, formatting defaults and tests
  + LaTeX formatting should be detected by `if "latex" in formatting` to catch both LaTeX format options.
  + All LaTeX reprs except for that of an entire model default to "latex_with_params".
  + Tests now cover cases with and without params.
* Consolidate regression test into existing one
  + `test_issue_4186` was merged into `TestStrAndLatexRepr`, and now all four formatting options are covered.
* Check graphviz results with all four formatting options
1 parent 8b3b701 commit 27b1a7d

File tree

9 files changed

+160
-91
lines changed

9 files changed

+160
-91
lines changed

pymc3/distributions/bart.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
246246
alpha = self.alpha
247247
m = self.m
248248

249-
if formatting == "latex":
249+
if "latex" in formatting:
250250
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
251251
else:
252252
return f"{name} ~ BART(alpha = {alpha}, m = {m})"

pymc3/distributions/bound.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ def _distr_name_for_repr(self):
157157

158158
def _str_repr(self, **kwargs):
    """Build the repr of a bounded variable from its own repr and the wrapped one.

    The wrapped distribution is rendered first; its leading "<name> ~ " part
    (or the LaTeX equivalent ending in the sim operator) is cut off so only
    the distribution part remains. That remainder is appended to the repr
    produced by the parent class.

    Parameters
    ----------
    **kwargs
        Forwarded to both ``_str_repr`` calls. An optional ``formatting``
        entry containing "latex" selects the LaTeX variants.

    Returns
    -------
    str
        Parent repr, a separator (" -- " for LaTeX, "-" otherwise) and the
        trimmed repr of the wrapped distribution.
    """
    wrapped_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
    use_latex = "latex" in kwargs.get("formatting", "")

    # The offset skips the marker plus the blank that follows it.
    marker, offset = (r" \sim", 6) if use_latex else (" ~", 3)
    wrapped_repr = wrapped_repr[wrapped_repr.index(marker) + offset :]

    bound_repr = super()._str_repr(**kwargs)
    separator = " -- " if use_latex else "-"
    return bound_repr + separator + wrapped_repr

pymc3/distributions/distribution.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -164,44 +164,61 @@ def _distr_name_for_repr(self):
164164
return self.__class__.__name__
165165

166166
def _str_repr(self, name=None, dist=None, formatting="plain"):
    """
    Generate string representation for this distribution, optionally
    including LaTeX markup and/or the distribution parameters.

    Parameters
    ----------
    name : str
        name of the distribution; "[unnamed]" is used when omitted
    dist : Distribution
        the distribution object; defaults to ``self``
    formatting : str
        one of { "latex", "plain", "latex_with_params", "plain_with_params" }

    Returns
    -------
    str
        "<name> ~ <distribution>" for the plain options, or the same wrapped
        in "$...$" with LaTeX markup for the latex options. The parameter
        list in parentheses is appended only for the "*_with_params" options.

    Raises
    ------
    ValueError
        If ``formatting`` is not one of the supported options.
    """
    if dist is None:
        dist = self
    if name is None:
        name = "[unnamed]"
    supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
    if formatting not in supported_formattings:
        # Name the rejected value and sort the options so the message is
        # both informative and deterministic.
        raise ValueError(
            f"Unsupported formatting '{formatting}'. "
            f"Choose one of {sorted(supported_formattings)}."
        )

    is_latex = "latex" in formatting
    distr_name = dist._distr_name_for_repr()

    if not formatting.endswith("_with_params"):
        # Parameters are not rendered, so their reprs are never computed.
        if is_latex:
            return fr"$\text{{{name}}} \sim \text{{{distr_name}}}$"
        return f"{name} ~ {distr_name}"

    param_names = self._distr_parameters_for_repr()
    # get_repr_for_variable is resolved from the enclosing module.
    param_values = [
        get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
    ]

    if is_latex:
        param_string = ",~".join(
            fr"\mathit{{{param}}}={value}" for param, value in zip(param_names, param_values)
        )
        return fr"$\text{{{name}}} \sim \text{{{distr_name}}}({param_string})$"

    param_string = ", ".join(
        f"{param}={value}" for param, value in zip(param_names, param_values)
    )
    return f"{name} ~ {distr_name}({param_string})"
195212

196213
def __str__(self, **kwargs):
    """Return the plain-text repr, falling back to the inherited ``__str__``.

    Parameters
    ----------
    **kwargs
        Forwarded to ``_str_repr`` (alongside ``formatting="plain"``).

    Returns
    -------
    str
        The result of ``_str_repr``, or the parent class' ``__str__`` when
        building that representation fails.
    """
    try:
        return self._str_repr(formatting="plain", **kwargs)
    except Exception:
        # Best-effort fallback. ``Exception`` instead of a bare ``except`` so
        # that KeyboardInterrupt/SystemExit are not silently swallowed.
        return super().__str__()
201218

202-
def _repr_latex_(self, **kwargs):
219+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
    """Magic method name for IPython to use for LaTeX formatting.

    Parameters
    ----------
    formatting : str, keyword-only
        Formatting option handed to ``_str_repr``; defaults to
        "latex_with_params".
    **kwargs
        Forwarded unchanged to ``_str_repr``.

    Returns
    -------
    Whatever ``_str_repr`` returns for the requested formatting.
    """
    latex_repr = self._str_repr(formatting=formatting, **kwargs)
    return latex_repr
205222

206223
def logp_nojac(self, *args, **kwargs):
207224
"""Return the logp, but do not include a jacobian term for transforms.

pymc3/distributions/simulator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
126126
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
127127
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)
128128

129-
if formatting == "latex":
129+
if "latex" in formatting:
130130
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
131131
else:
132132
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"

pymc3/model.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __rmatmul__(self, other):
6565

6666
def _str_repr(self, name=None, dist=None, formatting="plain"):
6767
if getattr(self, "distribution", None) is None:
68-
if formatting == "latex":
68+
if "latex" in formatting:
6969
return None
7070
else:
7171
return super().__str__()
@@ -76,8 +76,8 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
7676
dist = self.distribution
7777
return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)
7878

79-
def _repr_latex_(self, **kwargs):
80-
return self._str_repr(formatting="latex", **kwargs)
79+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
    """Return the LaTeX representation of this variable.

    ``formatting`` is keyword-only and defaults to "latex_with_params"; it and
    any extra keyword arguments are passed straight on to ``_str_repr``.
    """
    latex_repr = self._str_repr(formatting=formatting, **kwargs)
    return latex_repr
8181

8282
def __str__(self, **kwargs):
8383
try:
@@ -1375,8 +1375,8 @@ def check_test_point(self, test_point=None, round_vals=2):
13751375
def _str_repr(self, formatting="plain", **kwargs):
13761376
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)
13771377

1378-
if formatting == "latex":
1379-
rv_reprs = [rv.__latex__() for rv in all_rv]
1378+
if "latex" in formatting:
1379+
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
13801380
rv_reprs = [
13811381
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
13821382
for rv_repr in rv_reprs
@@ -1407,8 +1407,8 @@ def _str_repr(self, formatting="plain", **kwargs):
14071407
def __str__(self, **kwargs):
    """Return the plain-text representation produced by ``_str_repr``.

    Extra keyword arguments are forwarded to ``_str_repr`` together with
    ``formatting="plain"``.
    """
    plain_repr = self._str_repr(formatting="plain", **kwargs)
    return plain_repr
14091409

1410-
def _repr_latex_(self, **kwargs):
1411-
return self._str_repr(formatting="latex", **kwargs)
1410+
def _repr_latex_(self, *, formatting="latex", **kwargs):
    """Return the LaTeX representation produced by ``_str_repr``.

    ``formatting`` is keyword-only and defaults to "latex"; it and any extra
    keyword arguments are passed straight on to ``_str_repr``.
    """
    latex_repr = self._str_repr(formatting=formatting, **kwargs)
    return latex_repr
14121412

14131413
__latex__ = _repr_latex_
14141414

@@ -1874,24 +1874,27 @@ def _walk_up_rv(rv, formatting="plain"):
18741874
all_rvs.extend(_walk_up_rv(parent, formatting=formatting))
18751875
else:
18761876
name = rv.name if rv.name else "Constant"
1877-
fmt = r"\text{{{name}}}" if formatting == "latex" else "{name}"
1877+
fmt = r"\text{{{name}}}" if "latex" in formatting else "{name}"
18781878
all_rvs.append(fmt.format(name=name))
18791879
return all_rvs
18801880

18811881

18821882
class DeterministicWrapper(tt.TensorVariable):
18831883
def _str_repr(self, formatting="plain"):
    """Render this deterministic variable as "<name> ~ Deterministic".

    Parameters
    ----------
    formatting : str
        "latex_with_params" and "plain_with_params" append, in parentheses,
        the names collected by ``_walk_up_rv``. Any other value containing
        "latex" yields the LaTeX form without arguments, and everything else
        yields the plain form without arguments.

    Returns
    -------
    str
        The LaTeX forms are wrapped in "$...$".
    """
    # Evaluated first so a non-string ``formatting`` fails here, as before.
    is_latex = "latex" in formatting

    if formatting == "latex_with_params":
        args = r",~".join(_walk_up_rv(self, formatting=formatting))
        return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}({args})$"
    if is_latex:
        return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}$"
    if formatting == "plain_with_params":
        args = ", ".join(_walk_up_rv(self, formatting=formatting))
        return f"{self.name} ~ Deterministic({args})"
    return f"{self.name} ~ Deterministic"
18921895

1893-
def _repr_latex_(self):
1894-
return self._str_repr(formatting="latex")
1896+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
    """Return the LaTeX representation produced by ``_str_repr``.

    ``formatting`` is keyword-only and defaults to "latex_with_params". Any
    additional keyword arguments are accepted but not forwarded; only
    ``formatting`` reaches ``_str_repr``.
    """
    latex_repr = self._str_repr(formatting=formatting)
    return latex_repr
18951898

18961899
__latex__ = _repr_latex_
18971900

pymc3/model_graph.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def update_input_map(key: str, val: Set[VarName]):
121121
pass
122122
return input_map
123123

124-
def _make_node(self, var_name, graph):
124+
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
125125
"""Attaches the given variable to a graphviz Digraph"""
126126
v = self.model[var_name]
127127

@@ -146,7 +146,7 @@ def _make_node(self, var_name, graph):
146146
elif isinstance(v, SharedVariable):
147147
label = f"{var_name}\n~\nData"
148148
else:
149-
label = str(v).replace(" ~ ", "\n~\n")
149+
label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")
150150

151151
graph.node(var_name.replace(":", "&"), label, **attrs)
152152

@@ -181,7 +181,7 @@ def get_plates(self):
181181
plates[shape].add(var_name)
182182
return plates
183183

184-
def make_graph(self):
184+
def make_graph(self, formatting: str = "plain"):
185185
"""Make graphviz Digraph of PyMC3 model
186186
187187
Returns
@@ -205,20 +205,20 @@ def make_graph(self):
205205
# must be preceded by 'cluster' to get a box around it
206206
with graph.subgraph(name="cluster" + label) as sub:
207207
for var_name in var_names:
208-
self._make_node(var_name, sub)
208+
self._make_node(var_name, sub, formatting=formatting)
209209
# plate label goes bottom right
210210
sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
211211
else:
212212
for var_name in var_names:
213-
self._make_node(var_name, graph)
213+
self._make_node(var_name, graph, formatting=formatting)
214214

215215
for key, values in self.make_compute_graph().items():
216216
for value in values:
217217
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
218218
return graph
219219

220220

221-
def model_to_graphviz(model=None):
221+
def model_to_graphviz(model=None, *, formatting: str = "plain"):
222222
"""Produce a graphviz Digraph from a PyMC3 model.
223223
224224
Requires graphviz, which may be installed most easily with
@@ -228,6 +228,15 @@ def model_to_graphviz(model=None):
228228
and then `pip install graphviz` to get the python bindings. See
229229
http://graphviz.readthedocs.io/en/stable/manual.html
230230
for more information.
231+
232+
Parameters
233+
----------
234+
model : pm.Model
235+
The model to plot. Not required when called from inside a modelcontext.
236+
formatting : str
237+
one of { "plain", "plain_with_params" }
231238
"""
239+
if not "plain" in formatting:
240+
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
232241
model = pm.modelcontext(model)
233-
return ModelGraph(model).make_graph()
242+
return ModelGraph(model).make_graph(formatting=formatting)

pymc3/tests/test_data_container.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self):
179179
pm.Normal("obs", beta * x, obs_sigma, observed=y)
180180
pm.sample(1000, init=None, tune=1000, chains=1)
181181

182-
g = pm.model_to_graphviz(model)
183-
184-
# Data node rendered correctly?
185-
text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]'
186-
assert text in g.source
187-
# Didn't break ordinary variables?
188-
text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]'
189-
assert text in g.source
190-
text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]'
191-
assert text in g.source
182+
for formatting in {"latex", "latex_with_params"}:
183+
with pytest.raises(ValueError, match="Unsupported formatting"):
184+
pm.model_to_graphviz(model, formatting=formatting)
185+
186+
exp_without = [
187+
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
188+
'beta [label="beta\n~\nNormal"]',
189+
'obs [label="obs\n~\nNormal" style=filled]',
190+
]
191+
exp_with = [
192+
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
193+
'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]',
194+
f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]',
195+
]
196+
for formatting, expected_substrings in [
197+
("plain", exp_without),
198+
("plain_with_params", exp_with),
199+
]:
200+
g = pm.model_to_graphviz(model, formatting=formatting)
201+
# check formatting of RV nodes
202+
for expected in expected_substrings:
203+
assert expected in g.source
192204

193205
def test_explicit_coords(self):
194206
N_rows = 5

0 commit comments

Comments
 (0)