Skip to content

Commit d2bca90

Browse files
davidbrocharttwiecki
authored andcommitted
Added live_traceplot function (#1934)
* Added live_traceplot function * Cosmetic change * Changed the API to pm.sample(..., live_plot=True) * Don't include `-np.inf` in calculating average ELBO (#1880) * Adds an infmean for advi reporting * fixing typo * Add tutorial to detect sampling problems (#1866) * Expand sampler-stats.ipynb example include model diagnose from case study example in Stan http://mc-stan.org/documentation/case-studies/divergences_and_bias.html * Sampler Diagnose for NUTS * descriptive annotation and axis labels * Fix typos * PEP8 styling * minor updates 1, add example to examples.rst 2, original content in Markdown code block * Make install scripts idempotent (#1879) * DOC Change heading names. * Add examples of censored data models (#1870) * Raise TypeError on non-data values of observed (#1872) * Raise TypeError on non-data values of observed * Added check for observed TypeError * Make exponential mode have the correct shape * Fix support of LKJCorr * Added tutorial notebook on updating priors * Fixed y-axis bug in forestplot; added transform argument to summary * Style cleanup * Made small changes and executed the notebook * Added probit and invprobit functions * Added carriage return to end of file * Fixed indentation * Changed probit test to use assert_allclose * Fix tests for LKJCorr * Added warning for ignoring init arguments in sample * Kill stray tab * Improve performance of transformations * DOC Add new features * Bump version. * Added docs and scripts to MANIFEST * WIP: Implement opvi (#1694) * migrate useful functions from previous PR (cherry picked from commit 9f61ab4) * opvi draft (cherry picked from commit d0997ff) * made some test work (cherry picked from commit b1a87d5) * refactored approximation to support aevb (without test) * refactor opvi delete unnecessary methods from operator, change method order * change log_q_local computation * add full rank approximation * add more_params argument to ObjectiveFunction.updates (aevb case) * refactor density computation in full rank approximation * typo: cast dict values to list * typo: cast dict values to list * typo: undefined T in dist_math * refactor gradient scaling as suggested in approximateinference.org/accepted/RoederEtAl2016.pdf * implement Langevin-Stein (LS) operator * fix docstring * add blank line in docs * refactor ObjectiveFunction * add not working LS Op test * experiments with not working LS Op * change activations * refactor networks * add step_function * remove Langevin Stein, done refactoring * remove Langevin Stein, done refactoring * change optimizers * refactor init params * implement tests * implement Inference * code style * test fix * add minibatch test (fails now) * add more tests for minibatch training * add logdet to FullRank approximation * add conversion of arrays to floatX * tiny changes * change number of iterations * fix test and pylint check * memoize functions in Objective function * Optimize code a lot * a bit more efficient pickling * add docs * Add MeanField -> FullRank parameter transfer * refactor MeanField and FullRank a bit * fix FullRank bug with shapes in random * refactor Model.flatten (CC @taku-y) * add `approximate` to inference * rename approximate->fit * change abbreviations * Fix bug with scaling input variable in aevb * fix theano bottleneck in graph * more efficient scaling for local vars * fix typo in local Q * add aevb test * refactor memoize to work with my objects * add tests for numpy view usage * pickle-hash fix * pickle-hash fix again * add node sampling + make up some code * add notebook with example * sample_proba explained * Revert "small fix for multivariate mixture models" * Added message about init only working with auto-assigned step methods * doc(DiagInferDiv): formatting fix in blog post quote. Closes #1895. (#1909) * delete unnecessary text and add some benchmarks (#1901) * Add LKJCholeskyCov * Added newline to MANIFEST * Replaced package list with find_packages in setup.py; removed examples/data/__init__.py * Fix log jacobian in LKJCholeskyCov * Updated version to rc2 * Fixed stray version string * Fix indexing traces with steps greater one * refactor variational module, add histogram approximation (#1904) * refactor module, add histogram * add more tests * refactor some code concerning AEVB histogram * fix test for histogram * use mean as deterministic point in Histogram * remove unused import * change names of shortcuts * add names to shared params * add new line at the end of `approximations.py` * Add documentation for LKJCholeskyCov * SVGD problems (#1916) * fix some svgd problems * switch -> ifelse * except in record * Histogram docs (#1914) * add docs * delete redundant code * add usage example * remove unused import * Add expand_packed_triangular * improve aesthetics * Bump theano to 0.9.0rc4 (#1921) * Add tests for LKJCholeskyCov * Histogram: use only free RVs from trace (#1926) * use only free RVs from trace * use memoize in Histogram.histogram_logp * Change tests for histogram * Bump theano to be at least 0.9.0 * small fix to prevent a TypeError with the ufunc true_divide * Fix tests for py2 * Add floatX wrappers in test_advi * Changed the API to pm.sample(..., live_plot=True) * Better formatting
1 parent e6cd229 commit d2bca90

File tree

3 files changed

+181
-7
lines changed

3 files changed

+181
-7
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"deletable": true,
7+
"editable": true
8+
},
9+
"source": [
10+
"# Live sample plots"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"metadata": {
16+
"deletable": true,
17+
"editable": true
18+
},
19+
"source": [
20+
"This notebook illustrates how we can have live sample plots when calling the `sample` function with `live_plot=True`. It is based on the \"Coal mining disasters\" case study in the [Getting started notebook](https://github.com/pymc-devs/pymc3/blob/master/docs/source/notebooks/getting_started.ipynb)."
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {
27+
"collapsed": false,
28+
"deletable": true,
29+
"editable": true
30+
},
31+
"outputs": [],
32+
"source": [
33+
"import numpy as np\n",
34+
"from pymc3 import Model, Exponential, DiscreteUniform, Poisson, sample\n",
35+
"from pymc3.math import switch\n",
36+
"\n",
37+
"%matplotlib notebook"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"metadata": {
44+
"collapsed": false,
45+
"deletable": true,
46+
"editable": true
47+
},
48+
"outputs": [],
49+
"source": [
50+
"disaster_data = np.ma.masked_values([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,\n",
51+
" 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,\n",
52+
" 2, 2, 3, 4, 2, 1, 3, -999, 2, 1, 1, 1, 1, 3, 0, 0,\n",
53+
" 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,\n",
54+
" 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,\n",
55+
" 3, 3, 1, -999, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,\n",
56+
" 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], value=-999)\n",
57+
"year = np.arange(1851, 1962)"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"metadata": {
64+
"collapsed": false,
65+
"deletable": true,
66+
"editable": true
67+
},
68+
"outputs": [],
69+
"source": [
70+
"with Model() as disaster_model:\n",
71+
"\n",
72+
" switchpoint = DiscreteUniform('switchpoint', lower=year.min(), upper=year.max(), testval=1900)\n",
73+
"\n",
74+
" # Priors for pre- and post-switch rates number of disasters\n",
75+
" early_rate = Exponential('early_rate', 1)\n",
76+
" late_rate = Exponential('late_rate', 1)\n",
77+
"\n",
78+
" # Allocate appropriate Poisson rates to years before and after current\n",
79+
" rate = switch(switchpoint >= year, early_rate, late_rate)\n",
80+
"\n",
81+
" disasters = Poisson('disasters', rate, observed=disaster_data)"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": null,
87+
"metadata": {
88+
"collapsed": false,
89+
"deletable": true,
90+
"editable": true,
91+
"scrolled": false
92+
},
93+
"outputs": [],
94+
"source": [
95+
"with disaster_model:\n",
96+
" trace = sample(10000, live_plot=True, skip_first=100, refresh_every=300, roll_over=1000)"
97+
]
98+
}
99+
],
100+
"metadata": {
101+
"anaconda-cloud": {},
102+
"kernelspec": {
103+
"display_name": "Python [default]",
104+
"language": "python",
105+
"name": "python3"
106+
},
107+
"language_info": {
108+
"codemirror_mode": {
109+
"name": "ipython",
110+
"version": 3
111+
},
112+
"file_extension": ".py",
113+
"mimetype": "text/x-python",
114+
"name": "python",
115+
"nbconvert_exporter": "python",
116+
"pygments_lexer": "ipython3",
117+
"version": "3.5.2"
118+
}
119+
},
120+
"nbformat": 4,
121+
"nbformat_minor": 2
122+
}

pymc3/plots/traceplot.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, lines=None,
99
combined=False, plot_transformed=False, grid=False, alpha=0.35, priors=None,
10-
prior_alpha=1, prior_style='--', ax=None):
10+
prior_alpha=1, prior_style='--', ax=None, live_plot=False,
11+
skip_first=0, refresh_every=100, roll_over=1000):
1112
"""Plot samples histograms and values.
1213
1314
Parameters
@@ -45,6 +46,16 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
4546
Line style for prior plot. Defaults to '--' (dashed line).
4647
ax : axes
4748
Matplotlib axes. Accepts an array of axes, e.g.:
49+
live_plot: bool
50+
Flag for updating the current figure while sampling
51+
skip_first : int
52+
Number of first samples not shown in plots (burn-in). This affects
53+
frequency and stream plots.
54+
refresh_every : int
55+
Period of plot updates (in sample number)
56+
roll_over : int
57+
Width of the sliding window for the sample stream plots: last roll_over
58+
samples are shown (no effect on frequency plots).
4859
4960
>>> fig, axs = plt.subplots(3, 2) # 3 RVs
5061
>>> pymc3.traceplot(trace, ax=axs)
@@ -57,6 +68,8 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
5768
ax : matplotlib axes
5869
5970
"""
71+
trace = trace[skip_first:]
72+
6073
if varnames is None:
6174
varnames = get_default_varnames(trace, plot_transformed)
6275

@@ -70,9 +83,23 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
7083
prior = priors[i]
7184
else:
7285
prior = None
86+
first_time = True
7387
for d in trace.get_values(v, combine=combined, squeeze=False):
7488
d = np.squeeze(transform(d))
7589
d = make_2d(d)
90+
d_stream = d
91+
x0 = 0
92+
if live_plot:
93+
x0 = skip_first
94+
if first_time:
95+
ax[i, 0].cla()
96+
ax[i, 1].cla()
97+
first_time = False
98+
if roll_over is not None:
99+
if len(d) >= roll_over:
100+
x0 = len(d) - roll_over + skip_first
101+
d_stream = d[-roll_over:]
102+
width = len(d_stream)
76103
if d.dtype.kind == 'i':
77104
hist_objs = histplot_op(ax[i, 0], d, alpha=alpha)
78105
colors = [h[-1][0].get_facecolor() for h in hist_objs]
@@ -82,7 +109,7 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
82109
ax[i, 0].set_title(str(v))
83110
ax[i, 0].grid(grid)
84111
ax[i, 1].set_title(str(v))
85-
ax[i, 1].plot(d, alpha=alpha)
112+
ax[i, 1].plot(range(x0, x0 + width), d_stream, alpha=alpha)
86113

87114
ax[i, 0].set_ylabel("Frequency")
88115
ax[i, 1].set_ylabel("Sample value")
@@ -103,6 +130,13 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
103130
lw=1.5, alpha=alpha)
104131
except KeyError:
105132
pass
133+
if live_plot:
134+
for j in [0, 1]:
135+
ax[i, j].relim()
136+
ax[i, j].autoscale_view(True, True, True)
137+
ax[i, 1].set_xlim(x0, x0 + width)
106138
ax[i, 0].set_ylim(ymin=0)
139+
if live_plot:
140+
ax[0, 0].figure.canvas.draw()
107141
plt.tight_layout()
108142
return ax

pymc3/sampling.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
1212
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
1313
Slice, CompoundStep)
14+
from .plots.utils import identity_transform
15+
from .plots.traceplot import traceplot
1416
from tqdm import tqdm
1517

1618
import warnings
@@ -85,7 +87,7 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8587

8688
def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
8789
trace=None, chain=0, njobs=1, tune=None, progressbar=True,
88-
model=None, random_seed=-1):
90+
model=None, random_seed=-1, live_plot=False, **kwargs):
8991
"""Draw samples from the posterior using the given step methods.
9092
9193
Multiple step methods are supported via compound step methods.
@@ -141,6 +143,8 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
141143
model : Model (optional if in `with` context)
142144
random_seed : int or list of ints
143145
A list is accepted if more if `njobs` is greater than one.
146+
live_plot: bool
147+
Flag for live plotting the trace while sampling
144148
145149
Returns
146150
-------
@@ -175,7 +179,9 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
175179
'tune': tune,
176180
'progressbar': progressbar,
177181
'model': model,
178-
'random_seed': random_seed}
182+
'random_seed': random_seed,
183+
'live_plot': live_plot,
184+
**kwargs}
179185

180186
if njobs > 1:
181187
sample_func = _mp_sample
@@ -187,15 +193,27 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
187193

188194

189195
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
190-
progressbar=True, model=None, random_seed=-1):
196+
progressbar=True, model=None, random_seed=-1, live_plot=False,
197+
**kwargs):
198+
live_plot_args = {'skip_first': 0, 'refresh_every': 100}
199+
live_plot_args = {arg: kwargs[arg] if arg in kwargs else live_plot_args[arg] for arg in live_plot_args}
200+
skip_first = live_plot_args['skip_first']
201+
refresh_every = live_plot_args['refresh_every']
202+
191203
sampling = _iter_sample(draws, step, start, trace, chain,
192204
tune, model, random_seed)
193205
if progressbar:
194206
sampling = tqdm(sampling, total=draws)
195207
try:
196208
strace = None
197-
for strace in sampling:
198-
pass
209+
for it, strace in enumerate(sampling):
210+
if live_plot:
211+
if it >= skip_first:
212+
trace = MultiTrace([strace])
213+
if it == skip_first:
214+
ax = traceplot(trace, live_plot=False, **kwargs)
215+
elif (it - skip_first) % refresh_every == 0 or it == draws - 1:
216+
traceplot(trace, ax=ax, live_plot=True, **kwargs)
199217
except KeyboardInterrupt:
200218
pass
201219
finally:

0 commit comments

Comments
 (0)