You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
:author: Joshua Cook updated by Tyler James Burch, Chris Fonnesbeck
20
20
:::
21
21
22
22
+++ {"tags": []}
@@ -27,22 +27,10 @@ Often, the model we want to fit is not a perfect line between some $x$ and $y$.
27
27
Instead, the parameters of the model are expected to vary over $x$.
28
28
There are multiple ways to handle this situation, one of which is to fit a *spline*.
29
29
The spline is effectively multiple individual lines, each fit to a different section of $x$, that are tied together at their boundaries, often called *knots*.
30
-
Below is an exmaple of how to fit a spline using the Bayesian framework [PyMC3](https://docs.pymc.io).
31
30
32
-
Below is a full working example of how to fit a spline using the probabilitic programming language PyMC3.
33
-
The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.
34
-
As the book uses [Stan](https://mc-stan.org) (another advanced probabilitistic programming language), the modeling code is primarily taken from the [GitHub repository of the PyMC3 implementation of *Statistical Rethinking*](https://github.com/pymc-devs/resources/blob/master/Rethinking_2/Chp_04.ipynb).
35
-
My contributions are primarily of explanation and additional analyses of the data and results.
31
+
Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.
36
32
37
-
**Note that this is not a comprehensive review of splines – I primarily focus on the implementation in PyMC3.**
38
-
For more information on this method of non-linear modeling, I suggesting beginning with chapter 7.4 "Regression Splines" of *An Introduction to Statistical Learning* {cite:p}`james2021statisticallearning`.
39
-
40
-
+++
41
-
42
-
## Setup
43
-
44
-
For this example, I employ the standard data science and Bayesian data analysis packages.
45
-
In addition, the ['patsy'](https://patsy.readthedocs.io/en/latest/) library is used to generate the basis for the spline (more on that below).
33
+
For more information on this method of non-linear modeling, I suggesting beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`.
46
34
47
35
```{code-cell} ipython3
48
36
from pathlib import Path
@@ -51,8 +39,7 @@ import arviz as az
51
39
import matplotlib.pyplot as plt
52
40
import numpy as np
53
41
import pandas as pd
54
-
import pymc3 as pm
55
-
import statsmodels.api as sm
42
+
import pymc as pm
56
43
57
44
from patsy import dmatrix
58
45
```
@@ -62,14 +49,13 @@ from patsy import dmatrix
62
49
%config InlineBackend.figure_format = "retina"
63
50
64
51
RANDOM_SEED = 8927
65
-
rng = np.random.default_rng(RANDOM_SEED)
66
52
az.style.use("arviz-darkgrid")
67
53
```
68
54
69
55
## Cherry blossom data
70
56
71
-
The data for this example was the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
72
-
Years missing a `doy` were dropped.
57
+
The data for this example is the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
58
+
For convenience, years missing a `doy` were dropped (which is a bad idea to deal with missing data in general!).
73
59
74
60
```{code-cell} ipython3
75
61
try:
@@ -92,11 +78,11 @@ After dropping rows with missing data, there are 827 years with the numbers of d
92
78
blossom_data.shape
93
79
```
94
80
95
-
Below is a plot of the data we will be modeling showing the number of days of bloom per year.
81
+
If we visualize the data, it is clear that there a lot of annual variation, but some evidence for a non-linear trend in bloom days over time.
96
82
97
83
```{code-cell} ipython3
98
84
blossom_data.plot.scatter(
99
-
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
85
+
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
100
86
);
101
87
```
102
88
@@ -112,15 +98,11 @@ $\qquad a \sim \mathcal{N}(100, 10)$
112
98
$\qquad w \sim \mathcal{N}(0, 10)$
113
99
$\quad \sigma \sim \text{Exp}(1)$
114
100
115
-
The number of days of bloom will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$.
116
-
The mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis.
117
-
Both have relatively weak normal priors.
101
+
The number of days of bloom $D$ will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. In turn, the mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis. Both have relatively weak normal priors.
118
102
119
103
### Prepare the spline
120
104
121
-
The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data).
122
-
The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve.
123
-
The knots will be unevenly spaced over the years such that each region will have the same proportion of data.
105
+
The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data). The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each region will have the same proportion of data.
124
106
125
107
```{code-cell} ipython3
126
108
num_knots = 15
@@ -138,50 +120,7 @@ for knot in knot_list:
138
120
plt.gca().axvline(knot, color="grey", alpha=0.4);
139
121
```
140
122
141
-
Before doing any Bayesian modeling of the spline, we can get an idea of what our model should look like using the lowess modeling from `statsmodels`
142
-
143
-
```{code-cell} ipython3
144
-
blossom_data.plot.scatter(
145
-
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
Finally we can use 'patsy' to create the matrix $B$ that will be the b-spline basis for the regression.
123
+
We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression.
185
124
The degree is set to 3 to create a cubic b-spline.
186
125
187
126
```{code-cell} ipython3
@@ -194,10 +133,7 @@ B = dmatrix(
194
133
B
195
134
```
196
135
197
-
The b-spline basis is plotted below, showing the "domain" of each piece of the spline.
198
-
The height of each curve indicates how "influential" the corresponding model covariate (one per spline region) will be on model's "inference" of that region.
199
-
(The quotes are to indicate that these words were chosen to help with interpretation and are not the proper mathematical terms.)
200
-
The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.
136
+
The b-spline basis is plotted below, showing the *domain* of each piece of the spline. The height of each curve indicates how influential the corresponding model covariate (one per spline region) will be on model's inference of that region. The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.
A graphical diagram shows the organization of the model parameters (note that this requires the installation of 'python-graphviz' which is easiest in a `conda` virtual environment).
156
+
Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of `python-graphviz`, which I recommend doing in a `conda` virtual environment).
Now we can analyze the draws from the posterior of the model.
254
182
255
-
### Fit parameters
183
+
### Parameter Estimates
256
184
257
185
Below is a table summarizing the posterior distributions of the model parameters.
258
186
The posteriors of $a$ and $\sigma$ are quite narrow while those for $w$ are wider.
@@ -261,25 +189,24 @@ This is likely because all of the data points are used to estimate $a$ and $\sig
261
189
The effective sample size and $\widehat{R}$ values all look good, indiciating that the model has converged and sampled well from the posterior distribution.
262
190
263
191
```{code-cell} ipython3
264
-
az.summary(trace, var_names=["a", "w", "sigma"])
192
+
az.summary(idata, var_names=["a", "w", "sigma"])
265
193
```
266
194
267
-
The trace plots of the model parameters look good (fuzzy caterpillars), further indicating that the chains converged and mixed.
195
+
The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed.
Another visualization of the fit spline values is to plot them multiplied against the basis matrix.
278
-
The knot boundaries are shown in gray again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves).
279
-
The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in blue.
206
+
The knot boundaries are shown as vertical lines again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves). The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in black.
title="Cherry blossom data with posterior predictions",
328
-
ylabel="Day of Year",
255
+
ylabel="Days in bloom",
329
256
)
330
257
for knot in knot_list:
331
258
plt.gca().axvline(knot, color="grey", alpha=0.4)
@@ -340,33 +267,20 @@ plt.fill_between(
340
267
);
341
268
```
342
269
343
-
## Authors
344
-
345
-
- Authored by Joshua Cook in October, 2021
346
-
- Updated by [Tyler James Burch](https://github.com/tjburch) in October, 2021
347
-
348
-
+++
349
-
350
270
## References
351
271
352
272
:::{bibliography}
353
273
:filter: docname in docnames
354
274
:::
355
275
356
-
I would like to recognize the discussion ["Spline Regression in PyMC3"](https://discourse.pymc.io/t/spline-regression-in-pymc3/6235) on the PyMC3 Discourse as the inspiration of this example and for the helpful discussion and problem-solving that improved it further.
0 commit comments