Commit fd72bd4

pm.sample callback (#3737)
* adds a callback to pm.sample() that returns the trace up until that point and whether it is diverging
* add tests
* replace diverging attribute in callback with draw and remove return value inspection
* add notebook describing usage of `callback` in `pm.sample`
* add lost commit back in
1 parent db7c493 commit fd72bd4

3 files changed: +387 -8 lines changed

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Sample callback\n",
8+
"\n",
9+
"This notebook demonstrates the usage of the `callback` argument of `pm.sample`. A callback is a function that gets called for every sample drawn in a chain. It is called with the trace and the current draw as arguments; the trace contains all samples drawn so far for a single chain.\n",
10+
"\n",
11+
"The sampling process can be interrupted by raising a `KeyboardInterrupt` from inside the callback.\n",
12+
"\n",
13+
"Use cases for this callback include:\n",
14+
"\n",
15+
" - Stopping sampling when a number of effective samples is reached\n",
16+
" - Stopping sampling when there are too many divergences\n",
17+
" - Logging metrics to external tools (such as TensorBoard)\n",
18+
" \n",
19+
"We'll start by defining a simple model."
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 1,
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"import pymc3 as pm\n",
29+
"import numpy as np\n",
30+
"\n",
31+
"X = np.array([1, 2, 3, 4, 5])\n",
32+
"y = X * 2 + np.random.randn(len(X))\n",
33+
"with pm.Model() as model:\n",
34+
" \n",
35+
" intercept = pm.Normal('intercept', 0, 10)\n",
36+
" slope = pm.Normal('slope', 0, 10)\n",
37+
" \n",
38+
" mean = intercept + slope * X\n",
39+
" error = pm.HalfCauchy('error', 1)\n",
40+
" obs = pm.Normal('obs', mean, error, observed=y)\n",
41+
" "
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"We can then, for example, add a callback that stops sampling once 100 samples have been drawn, regardless of the number of draws requested in `pm.sample`."
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 9,
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"name": "stderr",
58+
"output_type": "stream",
59+
"text": [
60+
"Auto-assigning NUTS sampler...\n",
61+
"Initializing NUTS using jitter+adapt_diag...\n",
62+
"Sequential sampling (1 chains in 1 job)\n",
63+
"NUTS: [error, slope, intercept]\n"
64+
]
65+
},
66+
{
67+
"data": {
68+
"text/html": [
69+
"\n",
70+
" <div>\n",
71+
" <style>\n",
72+
" /* Turns off some styling */\n",
73+
" progress {\n",
74+
" /* gets rid of default border in Firefox and Opera. */\n",
75+
" border: none;\n",
76+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
77+
" background-size: auto;\n",
78+
" }\n",
79+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
80+
" background: #F44336;\n",
81+
" }\n",
82+
" </style>\n",
83+
" <progress value='0' class='progress-bar-interrupted' max='500', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
84+
" Interrupted\n",
85+
" </div>\n",
86+
" "
87+
],
88+
"text/plain": [
89+
"<IPython.core.display.HTML object>"
90+
]
91+
},
92+
"metadata": {},
93+
"output_type": "display_data"
94+
},
95+
{
96+
"name": "stderr",
97+
"output_type": "stream",
98+
"text": [
99+
"There were 12 divergences after tuning. Increase `target_accept` or reparameterize.\n",
100+
"The acceptance probability does not match the target. It is 0.5303940121554945, but should be close to 0.8. Try to increase the number of tuning steps.\n",
101+
"Only one chain was sampled, this makes it impossible to run some convergence checks\n"
102+
]
103+
},
104+
{
105+
"name": "stdout",
106+
"output_type": "stream",
107+
"text": [
108+
"100\n"
109+
]
110+
}
111+
],
112+
"source": [
113+
"\n",
114+
"def my_callback(trace, draw):\n",
115+
" if len(trace) >= 100:\n",
116+
" raise KeyboardInterrupt()\n",
117+
" \n",
118+
"with model:\n",
119+
" trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)\n",
120+
" \n",
121+
"print(len(trace))"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"metadata": {},
127+
"source": [
128+
"Note, however, that the trace passed to the callback only corresponds to a single chain. If we want to do calculations over multiple chains at once, we need a bit of extra machinery."
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 20,
134+
"metadata": {},
135+
"outputs": [
136+
{
137+
"name": "stderr",
138+
"output_type": "stream",
139+
"text": [
140+
"Auto-assigning NUTS sampler...\n",
141+
"Initializing NUTS using jitter+adapt_diag...\n",
142+
"Multiprocess sampling (2 chains in 2 jobs)\n",
143+
"NUTS: [error, slope, intercept]\n"
144+
]
145+
},
146+
{
147+
"data": {
148+
"text/html": [
149+
"\n",
150+
" <div>\n",
151+
" <style>\n",
152+
" /* Turns off some styling */\n",
153+
" progress {\n",
154+
" /* gets rid of default border in Firefox and Opera. */\n",
155+
" border: none;\n",
156+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
157+
" background-size: auto;\n",
158+
" }\n",
159+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
160+
" background: #F44336;\n",
161+
" }\n",
162+
" </style>\n",
163+
" <progress value='1000' class='' max='1000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
164+
" 100.00% [1000/1000 00:00<00:00 Sampling 2 chains, 518 divergences]\n",
165+
" </div>\n",
166+
" "
167+
],
168+
"text/plain": [
169+
"<IPython.core.display.HTML object>"
170+
]
171+
},
172+
"metadata": {},
173+
"output_type": "display_data"
174+
},
175+
{
176+
"name": "stdout",
177+
"output_type": "stream",
178+
"text": [
179+
"100\n",
180+
"200\n",
181+
"100\n",
182+
"300\n",
183+
"400\n",
184+
"200\n",
185+
"500\n",
186+
"300\n",
187+
"400\n",
188+
"500\n"
189+
]
190+
},
191+
{
192+
"name": "stderr",
193+
"output_type": "stream",
194+
"text": [
195+
"The chain contains only diverging samples. The model is probably misspecified.\n",
196+
"The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.\n",
197+
"There were 18 divergences after tuning. Increase `target_accept` or reparameterize.\n",
198+
"The acceptance probability does not match the target. It is 9.211751427765233e-155, but should be close to 0.8. Try to increase the number of tuning steps.\n",
199+
"The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.\n",
200+
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
201+
]
202+
}
203+
],
204+
"source": [
205+
"def my_callback(trace, draw):\n",
206+
" if len(trace) % 100 == 0:\n",
207+
" print(len(trace))\n",
208+
" \n",
209+
"with model:\n",
210+
" trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)\n",
211+
" "
212+
]
213+
},
214+
{
215+
"cell_type": "markdown",
216+
"metadata": {},
217+
"source": [
218+
"We can use the `draw.chain` attribute to figure out which chain the current draw and trace belong to. Combined with a convergence statistic such as r_hat, this lets us stop sampling once the chains have converged, regardless of the number of draws specified."
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": 128,
224+
"metadata": {},
225+
"outputs": [
226+
{
227+
"name": "stderr",
228+
"output_type": "stream",
229+
"text": [
230+
"Auto-assigning NUTS sampler...\n",
231+
"Initializing NUTS using jitter+adapt_diag...\n",
232+
"Multiprocess sampling (2 chains in 2 jobs)\n",
233+
"NUTS: [error, slope, intercept]\n"
234+
]
235+
},
236+
{
237+
"data": {
238+
"text/html": [
239+
"\n",
240+
" <div>\n",
241+
" <style>\n",
242+
" /* Turns off some styling */\n",
243+
" progress {\n",
244+
" /* gets rid of default border in Firefox and Opera. */\n",
245+
" border: none;\n",
246+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
247+
" background-size: auto;\n",
248+
" }\n",
249+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
250+
" background: #F44336;\n",
251+
" }\n",
252+
" </style>\n",
253+
" <progress value='3596' class='' max='202000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
254+
" 1.78% [3596/202000 00:02<02:10 Sampling 2 chains, 37 divergences]\n",
255+
" </div>\n",
256+
" "
257+
],
258+
"text/plain": [
259+
"<IPython.core.display.HTML object>"
260+
]
261+
},
262+
"metadata": {},
263+
"output_type": "display_data"
264+
},
265+
{
266+
"name": "stderr",
267+
"output_type": "stream",
268+
"text": [
269+
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
270+
]
271+
}
272+
],
273+
"source": [
274+
"import arviz as az\n",
275+
"class MyCallback:\n",
276+
" def __init__(self, every=1000, max_rhat=1.05):\n",
277+
" self.every = every\n",
278+
" self.max_rhat = max_rhat\n",
279+
" self.traces = {}\n",
280+
" \n",
281+
" def __call__(self, trace, draw):\n",
282+
" if draw.tuning:\n",
283+
" return\n",
284+
"\n",
285+
" self.traces[draw.chain] = trace\n",
286+
" if len(trace) % self.every == 0: \n",
287+
" multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))\n",
288+
" if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:\n",
289+
" raise KeyboardInterrupt\n",
290+
"\n",
291+
"with model:\n",
292+
" trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)\n",
293+
"\n"
294+
]
295+
}
296+
],
297+
"metadata": {
298+
"kernelspec": {
299+
"display_name": "Python 3",
300+
"language": "python",
301+
"name": "python3"
302+
},
303+
"language_info": {
304+
"codemirror_mode": {
305+
"name": "ipython",
306+
"version": 3
307+
},
308+
"file_extension": ".py",
309+
"mimetype": "text/x-python",
310+
"name": "python",
311+
"nbconvert_exporter": "python",
312+
"pygments_lexer": "ipython3",
313+
"version": "3.7.4"
314+
}
315+
},
316+
"nbformat": 4,
317+
"nbformat_minor": 2
318+
}
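
The per-chain bookkeeping that the notebook's `MyCallback` relies on can be tried out without running a sampler. The following is only a minimal sketch under stated assumptions: `FakeDraw`, `StopAfterTotalDraws` and `simulate` are hypothetical names that are not part of PyMC3, and `FakeDraw` models only the two attributes the notebook uses (`draw.chain` and `draw.tuning`). The loop in `simulate` stands in for `pm.sample` and may not reflect the order in which a real sampler invokes the callback.

```python
from collections import namedtuple

# Hypothetical stand-in for the draw object passed to the callback; only the
# `chain` and `tuning` attributes used in the notebook are modelled.
FakeDraw = namedtuple("FakeDraw", ["chain", "tuning"])


class StopAfterTotalDraws:
    """Stop once all chains together hold `max_total` non-tuning samples."""

    def __init__(self, max_total):
        self.max_total = max_total
        self.lengths = {}  # chain index -> length of that chain's trace

    def __call__(self, trace, draw):
        if draw.tuning:
            return  # skip tuning samples, as the notebook's MyCallback does
        # Each call only sees one chain, so remember its latest length.
        self.lengths[draw.chain] = len(trace)
        if sum(self.lengths.values()) >= self.max_total:
            # The notebook stops sampling by raising KeyboardInterrupt.
            raise KeyboardInterrupt


def simulate(callback, chains=2, draws=500):
    """Feed the callback interleaved draws from `chains` fake chains."""
    traces = {chain: [] for chain in range(chains)}
    try:
        for i in range(draws):
            for chain in range(chains):
                traces[chain].append(i)
                callback(traces[chain], FakeDraw(chain=chain, tuning=False))
    except KeyboardInterrupt:
        pass  # the callback asked to stop early
    return {chain: len(samples) for chain, samples in traces.items()}


callback = StopAfterTotalDraws(max_total=150)
lengths = simulate(callback)
print(lengths)
```

Keeping a dictionary keyed by `draw.chain` is the same design choice as `self.traces` in the notebook: because every call receives a single chain's trace, any statistic over all chains has to be assembled from state stored on the callback object.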
