Skip to content

Commit ada368c

Browse files
authored
Merge pull request #31 from pymc-devs/master
Sync Fork from Upstream Repo
2 parents 7454665 + fd72bd4 commit ada368c

File tree

5 files changed

+393
-27
lines changed

5 files changed

+393
-27
lines changed

RELEASE-NOTES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
- use [fastprogress](https://github.com/fastai/fastprogress) instead of tqdm [#3693](https://github.com/pymc-devs/pymc3/pull/3693)
77
- `DEMetropolis` can now tune both `lambda` and `scaling` parameters, but by default neither of them are tuned. See [#3743](https://github.com/pymc-devs/pymc3/pull/3743) for more info.
88

9+
### Maintenance
10+
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
11+
912
## PyMC3 3.8 (November 29 2019)
1013

1114
### New features

docs/source/notebooks/BEST.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@
303303
"cell_type": "markdown",
304304
"metadata": {},
305305
"source": [
306-
"Looking at the group differences, we can conclude that there are meaningful differences between the two groups for all three measures. For these comparisons, it is useful to use zero as a reference value (`ref_val`); providing this reference value yields cumulative probabilities for the posterior distribution on either side of the value. Thus, for the difference in means, 99.4% of the posterior probability is greater than zero, which suggests the group means are credibly different. The effect size and differences in standard deviation are similarly positive.\n",
306+
"Looking at the group differences below, we can conclude that there are meaningful differences between the two groups for all three measures. For these comparisons, it is useful to use zero as a reference value (`ref_val`); providing this reference value yields cumulative probabilities for the posterior distribution on either side of the value. Thus, for the difference of means, at least 97% of the posterior probability are greater than zero, which suggests the group means are credibly different. The effect size and differences in standard deviation are similarly positive.\n",
307307
"\n",
308-
"These estimates suggest that the \"smart drug\" increased both the expected scores, but also the variability in scores across the sample. So, this does not rule out the possibility that some recipients may be adversely affected by the drug at the same time others benefit."
308+
"These estimates suggest that the \"smart drug\" increased both the expected scores, but also the variability in scores across the sample. So, this does not rule out the possibility that some recipients may be adversely affected by the drug at the same time others benefit."
309309
]
310310
},
311311
{
@@ -515,7 +515,7 @@
515515
"name": "python",
516516
"nbconvert_exporter": "python",
517517
"pygments_lexer": "ipython3",
518-
"version": "3.7.2"
518+
"version": "3.7.3"
519519
},
520520
"latex_envs": {
521521
"bibliofile": "biblio.bib",
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Sample callback\n",
8+
"\n",
9+
"This notebook demonstrates the usage of the callback attribute in `pm.sample`. A callback is a function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw as arguments and will contain all samples for a single trace.\n",
10+
"\n",
11+
"The sampling process can be interrupted by throwing a `KeyboardInterrupt` from inside the callback.\n",
12+
"\n",
13+
"use-cases for this callback include:\n",
14+
"\n",
15+
" - Stopping sampling when a number of effective samples is reached\n",
16+
" - Stopping sampling when there are too many divergences\n",
17+
" - Logging metrics to external tools (such as TensorBoard)\n",
18+
" \n",
19+
"We'll start with defining a simple model"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 1,
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"import pymc3 as pm\n",
29+
"import numpy as np\n",
30+
"\n",
31+
"X = np.array([1, 2, 3, 4, 5])\n",
32+
"y = X * 2 + np.random.randn(len(X))\n",
33+
"with pm.Model() as model:\n",
34+
" \n",
35+
" intercept = pm.Normal('intercept', 0, 10)\n",
36+
" slope = pm.Normal('slope', 0, 10)\n",
37+
" \n",
38+
" mean = intercept + slope * X\n",
39+
" error = pm.HalfCauchy('error', 1)\n",
40+
" obs = pm.Normal('obs', mean, error, observed=y)\n",
41+
" "
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"We can then for example add a callback that stops sampling whenever 100 samples are made, regardless of the number of draws set in the `pm.sample`"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 9,
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"name": "stderr",
58+
"output_type": "stream",
59+
"text": [
60+
"Auto-assigning NUTS sampler...\n",
61+
"Initializing NUTS using jitter+adapt_diag...\n",
62+
"Sequential sampling (1 chains in 1 job)\n",
63+
"NUTS: [error, slope, intercept]\n"
64+
]
65+
},
66+
{
67+
"data": {
68+
"text/html": [
69+
"\n",
70+
" <div>\n",
71+
" <style>\n",
72+
" /* Turns off some styling */\n",
73+
" progress {\n",
74+
" /* gets rid of default border in Firefox and Opera. */\n",
75+
" border: none;\n",
76+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
77+
" background-size: auto;\n",
78+
" }\n",
79+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
80+
" background: #F44336;\n",
81+
" }\n",
82+
" </style>\n",
83+
" <progress value='0' class='progress-bar-interrupted' max='500', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
84+
" Interrupted\n",
85+
" </div>\n",
86+
" "
87+
],
88+
"text/plain": [
89+
"<IPython.core.display.HTML object>"
90+
]
91+
},
92+
"metadata": {},
93+
"output_type": "display_data"
94+
},
95+
{
96+
"name": "stderr",
97+
"output_type": "stream",
98+
"text": [
99+
"There were 12 divergences after tuning. Increase `target_accept` or reparameterize.\n",
100+
"The acceptance probability does not match the target. It is 0.5303940121554945, but should be close to 0.8. Try to increase the number of tuning steps.\n",
101+
"Only one chain was sampled, this makes it impossible to run some convergence checks\n"
102+
]
103+
},
104+
{
105+
"name": "stdout",
106+
"output_type": "stream",
107+
"text": [
108+
"100\n"
109+
]
110+
}
111+
],
112+
"source": [
113+
"\n",
114+
"def my_callback(trace, draw):\n",
115+
" if len(trace) >= 100:\n",
116+
" raise KeyboardInterrupt()\n",
117+
" \n",
118+
"with model:\n",
119+
" trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)\n",
120+
" \n",
121+
"print(len(trace))"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"metadata": {},
127+
"source": [
128+
"Something to note though, is that the trace we get passed in the callback only correspond to a single chain. That means that if we want to do calculations over multiple chains at once, we'll need a bit of machinery to make this possible."
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 20,
134+
"metadata": {},
135+
"outputs": [
136+
{
137+
"name": "stderr",
138+
"output_type": "stream",
139+
"text": [
140+
"Auto-assigning NUTS sampler...\n",
141+
"Initializing NUTS using jitter+adapt_diag...\n",
142+
"Multiprocess sampling (2 chains in 2 jobs)\n",
143+
"NUTS: [error, slope, intercept]\n"
144+
]
145+
},
146+
{
147+
"data": {
148+
"text/html": [
149+
"\n",
150+
" <div>\n",
151+
" <style>\n",
152+
" /* Turns off some styling */\n",
153+
" progress {\n",
154+
" /* gets rid of default border in Firefox and Opera. */\n",
155+
" border: none;\n",
156+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
157+
" background-size: auto;\n",
158+
" }\n",
159+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
160+
" background: #F44336;\n",
161+
" }\n",
162+
" </style>\n",
163+
" <progress value='1000' class='' max='1000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
164+
" 100.00% [1000/1000 00:00<00:00 Sampling 2 chains, 518 divergences]\n",
165+
" </div>\n",
166+
" "
167+
],
168+
"text/plain": [
169+
"<IPython.core.display.HTML object>"
170+
]
171+
},
172+
"metadata": {},
173+
"output_type": "display_data"
174+
},
175+
{
176+
"name": "stdout",
177+
"output_type": "stream",
178+
"text": [
179+
"100\n",
180+
"200\n",
181+
"100\n",
182+
"300\n",
183+
"400\n",
184+
"200\n",
185+
"500\n",
186+
"300\n",
187+
"400\n",
188+
"500\n"
189+
]
190+
},
191+
{
192+
"name": "stderr",
193+
"output_type": "stream",
194+
"text": [
195+
"The chain contains only diverging samples. The model is probably misspecified.\n",
196+
"The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.\n",
197+
"There were 18 divergences after tuning. Increase `target_accept` or reparameterize.\n",
198+
"The acceptance probability does not match the target. It is 9.211751427765233e-155, but should be close to 0.8. Try to increase the number of tuning steps.\n",
199+
"The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.\n",
200+
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
201+
]
202+
}
203+
],
204+
"source": [
205+
"def my_callback(trace, draw):\n",
206+
" if len(trace) % 100 == 0:\n",
207+
" print(len(trace))\n",
208+
" \n",
209+
"with model:\n",
210+
" trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)\n",
211+
" "
212+
]
213+
},
214+
{
215+
"cell_type": "markdown",
216+
"metadata": {},
217+
"source": [
218+
"We can use the `draw.chain` attribute to figure out which chain the current draw and trace belong to. Combined with some kind of convergence statistic like r_hat we can stop when we have converged, regardless of the amount of specified draws."
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": 128,
224+
"metadata": {},
225+
"outputs": [
226+
{
227+
"name": "stderr",
228+
"output_type": "stream",
229+
"text": [
230+
"Auto-assigning NUTS sampler...\n",
231+
"Initializing NUTS using jitter+adapt_diag...\n",
232+
"Multiprocess sampling (2 chains in 2 jobs)\n",
233+
"NUTS: [error, slope, intercept]\n"
234+
]
235+
},
236+
{
237+
"data": {
238+
"text/html": [
239+
"\n",
240+
" <div>\n",
241+
" <style>\n",
242+
" /* Turns off some styling */\n",
243+
" progress {\n",
244+
" /* gets rid of default border in Firefox and Opera. */\n",
245+
" border: none;\n",
246+
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
247+
" background-size: auto;\n",
248+
" }\n",
249+
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
250+
" background: #F44336;\n",
251+
" }\n",
252+
" </style>\n",
253+
" <progress value='3596' class='' max='202000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
254+
" 1.78% [3596/202000 00:02<02:10 Sampling 2 chains, 37 divergences]\n",
255+
" </div>\n",
256+
" "
257+
],
258+
"text/plain": [
259+
"<IPython.core.display.HTML object>"
260+
]
261+
},
262+
"metadata": {},
263+
"output_type": "display_data"
264+
},
265+
{
266+
"name": "stderr",
267+
"output_type": "stream",
268+
"text": [
269+
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
270+
]
271+
}
272+
],
273+
"source": [
274+
"import arviz as az\n",
275+
"class MyCallback:\n",
276+
" def __init__(self, every=1000, max_rhat=1.05):\n",
277+
" self.every = every\n",
278+
" self.max_rhat = max_rhat\n",
279+
" self.traces = {}\n",
280+
" \n",
281+
" def __call__(self, trace, draw):\n",
282+
" if draw.tuning:\n",
283+
" return\n",
284+
"\n",
285+
" self.traces[draw.chain] = trace\n",
286+
" if len(trace) % self.every == 0: \n",
287+
" multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))\n",
288+
" if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:\n",
289+
" raise KeyboardInterrupt\n",
290+
"\n",
291+
"with model:\n",
292+
" trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)\n",
293+
"\n"
294+
]
295+
}
296+
],
297+
"metadata": {
298+
"kernelspec": {
299+
"display_name": "Python 3",
300+
"language": "python",
301+
"name": "python3"
302+
},
303+
"language_info": {
304+
"codemirror_mode": {
305+
"name": "ipython",
306+
"version": 3
307+
},
308+
"file_extension": ".py",
309+
"mimetype": "text/x-python",
310+
"name": "python",
311+
"nbconvert_exporter": "python",
312+
"pygments_lexer": "ipython3",
313+
"version": "3.7.4"
314+
}
315+
},
316+
"nbformat": 4,
317+
"nbformat_minor": 2
318+
}

0 commit comments

Comments
 (0)