Skip to content

Commit 30214af

Browse files
committed
Merge remote-tracking branch 'upstream/master' into less-strict-test
2 parents 190e986 + 5ff9bbc commit 30214af

19 files changed

+832
-126
lines changed

RELEASE-NOTES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pip install theano-pymc
2626
This new version of `Theano-PyMC` comes with an experimental JAX backend which, when combined with the new and experimental JAX samplers in PyMC3, can greatly speed up sampling in your model. As this is still very new, please do not use it in production yet but do test it out and let us know if anything breaks and what results you are seeing, especially speed-wise.
2727

2828
### New features
29+
- New experimental JAX samplers in `pymc3.sample_jax` (see [notebook](https://docs.pymc.io/notebooks/GLM-hierarchical-jax.html) and [#4247](https://github.com/pymc-devs/pymc3/pull/4247)). Requires JAX and either TFP or numpyro.
2930
- Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
3031
- Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
3132
- Added `pymc3.gp.cov.Circular` kernel for Gaussian Processes on circular domains, e.g. the unit circle (see [#4082](https://github.com/pymc-devs/pymc3/pull/4082)).
@@ -42,6 +43,8 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
4243
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
4344
- Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
4445
- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
46+
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
47+
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
4548

4649
### Documentation
4750
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).

docs/source/notebooks/GLM-hierarchical-jax.ipynb

Lines changed: 384 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/table_of_contents_examples.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,6 @@ Gallery.contents = {
6464
"MLDA_introduction": "MCMC",
6565
"MLDA_simple_linear_regression": "MCMC",
6666
"MLDA_gravity_surveying": "MCMC",
67-
"MLDA_variance_reduction_linear_regression": "MCMC"
67+
"MLDA_variance_reduction_linear_regression": "MCMC",
68+
"GLM-hierarchical-jax": "MCMC"
6869
}

pymc3/distributions/bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
246246
alpha = self.alpha
247247
m = self.m
248248

249-
if formatting == "latex":
249+
if "latex" in formatting:
250250
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
251251
else:
252252
return f"{name} ~ BART(alpha = {alpha}, m = {m})"

pymc3/distributions/bound.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ def _distr_name_for_repr(self):
157157

158158
def _str_repr(self, **kwargs):
159159
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
160-
if "formatting" in kwargs and kwargs["formatting"] == "latex":
160+
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
161161
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
162162
else:
163163
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
164164
self_repr = super()._str_repr(**kwargs)
165165

166-
if "formatting" in kwargs and kwargs["formatting"] == "latex":
166+
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
167167
return self_repr + " -- " + distr_repr
168168
else:
169169
return self_repr + "-" + distr_repr

pymc3/distributions/distribution.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,44 +164,61 @@ def _distr_name_for_repr(self):
164164
return self.__class__.__name__
165165

166166
def _str_repr(self, name=None, dist=None, formatting="plain"):
167-
"""Generate string representation for this distribution, optionally
167+
"""
168+
Generate string representation for this distribution, optionally
168169
including LaTeX markup (formatting='latex').
170+
171+
Parameters
172+
----------
173+
name : str
174+
name of the distribution
175+
dist : Distribution
176+
the distribution object
177+
formatting : str
178+
one of { "latex", "plain", "latex_with_params", "plain_with_params" }
169179
"""
170180
if dist is None:
171181
dist = self
172182
if name is None:
173183
name = "[unnamed]"
184+
supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
185+
if not formatting in supported_formattings:
186+
raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.")
174187

175188
param_names = self._distr_parameters_for_repr()
176189
param_values = [
177190
get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
178191
]
179192

180-
if formatting == "latex":
193+
if "latex" in formatting:
181194
param_string = ",~".join(
182195
[fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)]
183196
)
184-
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
185-
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
197+
if formatting == "latex_with_params":
198+
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
199+
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
200+
)
201+
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format(
202+
var_name=name, distr_name=dist._distr_name_for_repr()
186203
)
187204
else:
188-
# 'plain' is default option
205+
# one of the plain formattings
189206
param_string = ", ".join(
190207
[f"{name}={value}" for name, value in zip(param_names, param_values)]
191208
)
192-
return "{var_name} ~ {distr_name}({params})".format(
193-
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
194-
)
209+
if formatting == "plain_with_params":
210+
return f"{name} ~ {dist._distr_name_for_repr()}({param_string})"
211+
return f"{name} ~ {dist._distr_name_for_repr()}"
195212

196213
def __str__(self, **kwargs):
197214
try:
198215
return self._str_repr(formatting="plain", **kwargs)
199216
except:
200217
return super().__str__()
201218

202-
def _repr_latex_(self, **kwargs):
219+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
203220
"""Magic method name for IPython to use for LaTeX formatting."""
204-
return self._str_repr(formatting="latex", **kwargs)
221+
return self._str_repr(formatting=formatting, **kwargs)
205222

206223
def logp_nojac(self, *args, **kwargs):
207224
"""Return the logp, but do not include a jacobian term for transforms.

pymc3/distributions/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
126126
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
127127
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)
128128

129-
if formatting == "latex":
129+
if "latex" in formatting:
130130
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
131131
else:
132132
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"

pymc3/model.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __rmatmul__(self, other):
6565

6666
def _str_repr(self, name=None, dist=None, formatting="plain"):
6767
if getattr(self, "distribution", None) is None:
68-
if formatting == "latex":
68+
if "latex" in formatting:
6969
return None
7070
else:
7171
return super().__str__()
@@ -76,8 +76,8 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
7676
dist = self.distribution
7777
return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)
7878

79-
def _repr_latex_(self, **kwargs):
80-
return self._str_repr(formatting="latex", **kwargs)
79+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
80+
return self._str_repr(formatting=formatting, **kwargs)
8181

8282
def __str__(self, **kwargs):
8383
try:
@@ -1368,15 +1368,15 @@ def check_test_point(self, test_point=None, round_vals=2):
13681368
test_point = self.test_point
13691369

13701370
return Series(
1371-
{RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
1371+
{RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
13721372
name="Log-probability of test_point",
13731373
)
13741374

13751375
def _str_repr(self, formatting="plain", **kwargs):
13761376
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)
13771377

1378-
if formatting == "latex":
1379-
rv_reprs = [rv.__latex__() for rv in all_rv]
1378+
if "latex" in formatting:
1379+
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
13801380
rv_reprs = [
13811381
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
13821382
for rv_repr in rv_reprs
@@ -1407,8 +1407,8 @@ def _str_repr(self, formatting="plain", **kwargs):
14071407
def __str__(self, **kwargs):
14081408
return self._str_repr(formatting="plain", **kwargs)
14091409

1410-
def _repr_latex_(self, **kwargs):
1411-
return self._str_repr(formatting="latex", **kwargs)
1410+
def _repr_latex_(self, *, formatting="latex", **kwargs):
1411+
return self._str_repr(formatting=formatting, **kwargs)
14121412

14131413
__latex__ = _repr_latex_
14141414

@@ -1874,24 +1874,27 @@ def _walk_up_rv(rv, formatting="plain"):
18741874
all_rvs.extend(_walk_up_rv(parent, formatting=formatting))
18751875
else:
18761876
name = rv.name if rv.name else "Constant"
1877-
fmt = r"\text{{{name}}}" if formatting == "latex" else "{name}"
1877+
fmt = r"\text{{{name}}}" if "latex" in formatting else "{name}"
18781878
all_rvs.append(fmt.format(name=name))
18791879
return all_rvs
18801880

18811881

18821882
class DeterministicWrapper(tt.TensorVariable):
18831883
def _str_repr(self, formatting="plain"):
1884-
if formatting == "latex":
1885-
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1886-
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
1887-
)
1884+
if "latex" in formatting:
1885+
if formatting == "latex_with_params":
1886+
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1887+
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
1888+
)
1889+
return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}$"
18881890
else:
1889-
return "{name} ~ Deterministic({args})".format(
1890-
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
1891-
)
1891+
if formatting == "plain_with_params":
1892+
args = ", ".join(_walk_up_rv(self, formatting=formatting))
1893+
return f"{self.name} ~ Deterministic({args})"
1894+
return f"{self.name} ~ Deterministic"
18921895

1893-
def _repr_latex_(self):
1894-
return self._str_repr(formatting="latex")
1896+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
1897+
return self._str_repr(formatting=formatting)
18951898

18961899
__latex__ = _repr_latex_
18971900

pymc3/model_graph.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def update_input_map(key: str, val: Set[VarName]):
121121
pass
122122
return input_map
123123

124-
def _make_node(self, var_name, graph):
124+
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
125125
"""Attaches the given variable to a graphviz Digraph"""
126126
v = self.model[var_name]
127127

@@ -146,7 +146,7 @@ def _make_node(self, var_name, graph):
146146
elif isinstance(v, SharedVariable):
147147
label = f"{var_name}\n~\nData"
148148
else:
149-
label = str(v).replace(" ~ ", "\n~\n")
149+
label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")
150150

151151
graph.node(var_name.replace(":", "&"), label, **attrs)
152152

@@ -181,7 +181,7 @@ def get_plates(self):
181181
plates[shape].add(var_name)
182182
return plates
183183

184-
def make_graph(self):
184+
def make_graph(self, formatting: str = "plain"):
185185
"""Make graphviz Digraph of PyMC3 model
186186
187187
Returns
@@ -205,20 +205,20 @@ def make_graph(self):
205205
# must be preceded by 'cluster' to get a box around it
206206
with graph.subgraph(name="cluster" + label) as sub:
207207
for var_name in var_names:
208-
self._make_node(var_name, sub)
208+
self._make_node(var_name, sub, formatting=formatting)
209209
# plate label goes bottom right
210210
sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
211211
else:
212212
for var_name in var_names:
213-
self._make_node(var_name, graph)
213+
self._make_node(var_name, graph, formatting=formatting)
214214

215215
for key, values in self.make_compute_graph().items():
216216
for value in values:
217217
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
218218
return graph
219219

220220

221-
def model_to_graphviz(model=None):
221+
def model_to_graphviz(model=None, *, formatting: str = "plain"):
222222
"""Produce a graphviz Digraph from a PyMC3 model.
223223
224224
Requires graphviz, which may be installed most easily with
@@ -228,6 +228,15 @@ def model_to_graphviz(model=None):
228228
and then `pip install graphviz` to get the python bindings. See
229229
http://graphviz.readthedocs.io/en/stable/manual.html
230230
for more information.
231+
232+
Parameters
233+
----------
234+
model : pm.Model
235+
The model to plot. Not required when called from inside a modelcontext.
236+
formatting : str
237+
one of { "plain", "plain_with_params" }
231238
"""
239+
if not "plain" in formatting:
240+
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
232241
model = pm.modelcontext(model)
233-
return ModelGraph(model).make_graph()
242+
return ModelGraph(model).make_graph(formatting=formatting)

pymc3/sampling.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
PGBART,
5555
)
5656
from .util import (
57+
check_start_vals,
5758
update_start_vals,
5859
get_untransformed_name,
5960
is_transformed_name,
@@ -419,7 +420,16 @@ def sample(
419420
420421
"""
421422
model = modelcontext(model)
423+
if start is None:
424+
start = model.test_point
425+
else:
426+
if isinstance(start, dict):
427+
update_start_vals(start, model.test_point, model)
428+
else:
429+
for chain_start_vals in start:
430+
update_start_vals(chain_start_vals, model.test_point, model)
422431

432+
check_start_vals(start, model)
423433
if cores is None:
424434
cores = min(4, _cpu_count())
425435

@@ -487,6 +497,7 @@ def sample(
487497
progressbar=progressbar,
488498
**kwargs,
489499
)
500+
check_start_vals(start_, model)
490501
if start is None:
491502
start = start_
492503
except (AttributeError, NotImplementedError, tg.NullTypeGradError):

0 commit comments

Comments
 (0)