Commit 1665736

Updated splines notebook to v4 (#274)
* Updated splines notebook to v4
* Re-run for new output
* Fixes to spline notebook based on feedback
* Improved plot axis labels
1 parent acc68bb commit 1665736

File tree: 2 files changed, +448 -655 lines


examples/splines/spline.ipynb: +415 -536 (large diff not rendered by default)

myst_nbs/splines/spline.myst.md: +33 -119
@@ -6,17 +6,17 @@ jupytext:
     format_version: 0.13
     jupytext_version: 1.13.7
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: Python 3.9.10 ('pymc-dev-py39')
   language: python
   name: python3
 ---

 # Splines in PyMC3

-:::{post} Oct 8, 2021
-:tags: patsy, pymc3.Deterministic, pymc3.Exponential, pymc3.Model, pymc3.Normal, regression, spline
+:::{post} May 6, 2022
+:tags: patsy, pymc.Deterministic, pymc.Exponential, pymc.Model, pymc.Normal, regression, spline
 :category: beginner
-:author: Joshua Cook, Tyler James Burch
+:author: Joshua Cook updated by Tyler James Burch, Chris Fonnesbeck
 :::

 +++ {"tags": []}
@@ -27,22 +27,10 @@ Often, the model we want to fit is not a perfect line between some $x$ and $y$.
 Instead, the parameters of the model are expected to vary over $x$.
 There are multiple ways to handle this situation, one of which is to fit a *spline*.
 The spline is effectively multiple individual lines, each fit to a different section of $x$, that are tied together at their boundaries, often called *knots*.
-Below is an exmaple of how to fit a spline using the Bayesian framework [PyMC3](https://docs.pymc.io).

-Below is a full working example of how to fit a spline using the probabilitic programming language PyMC3.
-The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.
-As the book uses [Stan](https://mc-stan.org) (another advanced probabilitistic programming language), the modeling code is primarily taken from the [GitHub repository of the PyMC3 implementation of *Statistical Rethinking*](https://github.com/pymc-devs/resources/blob/master/Rethinking_2/Chp_04.ipynb).
-My contributions are primarily of explanation and additional analyses of the data and results.
+Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.

-**Note that this is not a comprehensive review of splines – I primarily focus on the implementation in PyMC3.**
-For more information on this method of non-linear modeling, I suggesting beginning with chapter 7.4 "Regression Splines" of *An Introduction to Statistical Learning* {cite:p}`james2021statisticallearning`.
-
-+++
-
-## Setup
-
-For this example, I employ the standard data science and Bayesian data analysis packages.
-In addition, the ['patsy'](https://patsy.readthedocs.io/en/latest/) library is used to generate the basis for the spline (more on that below).
+For more information on this method of non-linear modeling, I suggesting beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`.

 ```{code-cell} ipython3
 from pathlib import Path
@@ -51,8 +39,7 @@ import arviz as az
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import pymc3 as pm
-import statsmodels.api as sm
+import pymc as pm

 from patsy import dmatrix
 ```
@@ -62,14 +49,13 @@ from patsy import dmatrix
 %config InlineBackend.figure_format = "retina"

 RANDOM_SEED = 8927
-rng = np.random.default_rng(RANDOM_SEED)
 az.style.use("arviz-darkgrid")
 ```

 ## Cherry blossom data

-The data for this example was the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
-Years missing a `doy` were dropped.
+The data for this example is the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
+For convenience, years missing a `doy` were dropped (which is a bad idea to deal with missing data in general!).

 ```{code-cell} ipython3
 try:
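(Aside, not part of the diff: the data-loading cell is truncated at this hunk boundary. A minimal sketch of the step the prose describes, dropping years with a missing `doy`, is shown below; the file name `cherry_blossoms.csv` and the `;` separator are assumptions for illustration, not taken from the notebook.)

```python
import pandas as pd

# Hypothetical stand-in for the notebook's data-loading cell: read the cherry
# blossom table and keep only years with an observed bloom day ("doy").
blossom_data = pd.read_csv("cherry_blossoms.csv", sep=";")
blossom_data = blossom_data.dropna(subset=["doy"]).reset_index(drop=True)
print(blossom_data.shape)  # the notebook reports 827 remaining years
```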
@@ -92,11 +78,11 @@ After dropping rows with missing data, there are 827 years with the numbers of d
 blossom_data.shape
 ```

-Below is a plot of the data we will be modeling showing the number of days of bloom per year.
+If we visualize the data, it is clear that there a lot of annual variation, but some evidence for a non-linear trend in bloom days over time.

 ```{code-cell} ipython3
 blossom_data.plot.scatter(
-    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
+    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
 );
 ```

@@ -112,15 +98,11 @@ $\qquad a \sim \mathcal{N}(100, 10)$
 $\qquad w \sim \mathcal{N}(0, 10)$
 $\quad \sigma \sim \text{Exp}(1)$

-The number of days of bloom will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$.
-The mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis.
-Both have relatively weak normal priors.
+The number of days of bloom $D$ will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. In turn, the mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis. Both have relatively weak normal priors.

 ### Prepare the spline

-The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data).
-The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve.
-The knots will be unevenly spaced over the years such that each region will have the same proportion of data.
+The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data). The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each region will have the same proportion of data.

 ```{code-cell} ipython3
 num_knots = 15
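(Aside, not part of the diff: the knot-placement code is cut off at this hunk boundary. A minimal sketch of knots placed at evenly spaced quantiles of `year`, so each interval holds roughly the same share of observations, assuming `blossom_data` from earlier, might look like this; the notebook's exact cell may differ.)

```python
import numpy as np

num_knots = 15
# Evenly spaced quantiles of the observed years: each of the 16 resulting
# regions then contains roughly the same proportion of the data.
knot_list = np.quantile(blossom_data.year, np.linspace(0, 1, num_knots))
print(knot_list)
```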
@@ -138,50 +120,7 @@ for knot in knot_list:
     plt.gca().axvline(knot, color="grey", alpha=0.4);
 ```

-Before doing any Bayesian modeling of the spline, we can get an idea of what our model should look like using the lowess modeling from `statsmodels`
-
-```{code-cell} ipython3
-blossom_data.plot.scatter(
-    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
-)
-for knot in knot_list:
-    plt.gca().axvline(knot, color="grey", alpha=0.4)
-
-lowess = sm.nonparametric.lowess
-lowess_data = lowess(blossom_data.doy, blossom_data.year, frac=0.2, it=10)
-plt.plot(lowess_data[:, 0], lowess_data[:, 1], color="firebrick", lw=2);
-```
-
-Another way of visualizing what the spline should look like is to plot individual linear models over the data between each knot.
-The spline will effectively be a compromise between these individual models and a continuous curve.
-
-```{code-cell} ipython3
-blossom_data["knot_group"] = [np.where(a <= knot_list)[0][0] for a in blossom_data.year]
-blossom_data["knot_group"] = pd.Categorical(blossom_data["knot_group"], ordered=True)
-```
-
-```{code-cell} ipython3
-blossom_data.plot.scatter(
-    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
-)
-for knot in knot_list:
-    plt.gca().axvline(knot, color="grey", alpha=0.4)
-
-for i in np.arange(len(knot_list) - 1):
-
-    # Subset data to knot
-    x_range = (knot_list[i], knot_list[i + 1])
-    subset = blossom_data.query(f"year > {x_range[0]} & year <= {x_range[1]}")
-
-    # Create a linear model and predict values
-    lm = sm.OLS(subset.doy, sm.add_constant(subset.year, prepend=False)).fit()
-    x_vals = np.linspace(x_range[0], x_range[1], 100)
-    y_vals = lm.predict(sm.add_constant(x_vals, prepend=False))
-    # Add to plot
-    plt.plot(x_vals, y_vals, color="firebrick", lw=2)
-```
-
-Finally we can use 'patsy' to create the matrix $B$ that will be the b-spline basis for the regression.
+We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression.
 The degree is set to 3 to create a cubic b-spline.

 ```{code-cell} ipython3
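(Aside, not part of the diff: the `dmatrix` call itself falls between this hunk and the next. A sketch of a cubic b-spline basis built with `patsy`, assuming `blossom_data` and `knot_list` as above; the exact formula string used in the notebook may differ.)

```python
from patsy import dmatrix

# Cubic (degree=3) b-spline basis over `year`, using only the interior knots;
# "- 1" drops patsy's implicit intercept column. This is an illustrative
# sketch, not necessarily the notebook's exact call.
B = dmatrix(
    "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
    {"year": blossom_data.year.values, "knots": knot_list[1:-1]},
)
```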
@@ -194,10 +133,7 @@ B = dmatrix(
 B
 ```

-The b-spline basis is plotted below, showing the "domain" of each piece of the spline.
-The height of each curve indicates how "influential" the corresponding model covariate (one per spline region) will be on model's "inference" of that region.
-(The quotes are to indicate that these words were chosen to help with interpretation and are not the proper mathematical terms.)
-The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.
+The b-spline basis is plotted below, showing the *domain* of each piece of the spline. The height of each curve indicates how influential the corresponding model covariate (one per spline region) will be on model's inference of that region. The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.

 ```{code-cell} ipython3
 spline_df = (
@@ -217,17 +153,16 @@ plt.legend(title="Spline Index", loc="upper center", fontsize=8, ncol=6);

 ### Fit the model

-Finally, the model can be built using PyMC3.
-A graphical diagram shows the organization of the model parameters (note that this requires the installation of 'python-graphviz' which is easiest in a `conda` virtual environment).
+Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of `python-graphviz`, which I recommend doing in a `conda` virtual environment).

 ```{code-cell} ipython3
-COORDS = {"obs": np.arange(len(blossom_data.doy)), "splines": np.arange(B.shape[1])}
-with pm.Model(coords=COORDS) as spline_model:
+COORDS = {"splines": np.arange(B.shape[1])}
+with pm.Model(coords=COORDS, rng_seeder=RANDOM_SEED) as spline_model:
     a = pm.Normal("a", 100, 5)
-    w = pm.Normal("w", mu=0, sd=3, dims="splines")
+    w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
     mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
     sigma = pm.Exponential("sigma", 1)
-    D = pm.Normal("D", mu, sigma, observed=blossom_data.doy, dims="obs")
+    D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
 ```

 ```{code-cell} ipython3
@@ -236,23 +171,16 @@ pm.model_to_graphviz(spline_model)

 ```{code-cell} ipython3
 with spline_model:
-    prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
-    trace = pm.sample(
-        draws=1000,
-        tune=1000,
-        random_seed=RANDOM_SEED,
-        chains=4,
-        return_inferencedata=True,
-    )
-    post_pred = pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED)
-    trace.extend(az.from_pymc3(prior=prior_pred, posterior_predictive=post_pred))
+    idata = pm.sample_prior_predictive()
+    idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
+    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
 ```

 ## Analysis

 Now we can analyze the draws from the posterior of the model.

-### Fit parameters
+### Parameter Estimates

 Below is a table summarizing the posterior distributions of the model parameters.
 The posteriors of $a$ and $\sigma$ are quite narrow while those for $w$ are wider.
@@ -261,25 +189,24 @@ This is likely because all of the data points are used to estimate $a$ and $\sig
 The effective sample size and $\widehat{R}$ values all look good, indiciating that the model has converged and sampled well from the posterior distribution.

 ```{code-cell} ipython3
-az.summary(trace, var_names=["a", "w", "sigma"])
+az.summary(idata, var_names=["a", "w", "sigma"])
 ```

-The trace plots of the model parameters look good (fuzzy caterpillars), further indicating that the chains converged and mixed.
+The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed.

 ```{code-cell} ipython3
-az.plot_trace(trace, var_names=["a", "w", "sigma"]);
+az.plot_trace(idata, var_names=["a", "w", "sigma"]);
 ```

 ```{code-cell} ipython3
-az.plot_forest(trace, var_names=["w"], combined=False);
+az.plot_forest(idata, var_names=["w"], combined=False, r_hat=True);
 ```

 Another visualization of the fit spline values is to plot them multiplied against the basis matrix.
-The knot boundaries are shown in gray again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves).
-The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in blue.
+The knot boundaries are shown as vertical lines again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves). The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in black.

 ```{code-cell} ipython3
-wp = trace.posterior["w"].values.mean(axis=(0, 1))
+wp = idata.posterior["w"].mean(("chain", "draw")).values

 spline_df = (
     pd.DataFrame(B * wp.T)
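(Aside, not part of the diff: the fitted curve described above is just the dot product of the basis matrix and the posterior-mean weights. A sketch, assuming `B`, `wp`, and `blossom_data` as defined in the surrounding cells.)

```python
import matplotlib.pyplot as plt
import numpy as np

# The spline's contribution at each year is B @ wp; adding the posterior mean
# of the intercept `a` would give the fitted mean bloom day.
fitted_spline = np.asarray(B) @ wp
plt.plot(blossom_data.year, fitted_spline, color="black", lw=3)
plt.xlabel("year")
plt.ylabel("spline contribution to mean doy");
```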
@@ -311,7 +238,7 @@ for knot in knot_list:
 Lastly, we can visualize the predictions of the model using the posterior predictive check.

 ```{code-cell} ipython3
-post_pred = az.summary(trace, var_names=["mu"]).reset_index(drop=True)
+post_pred = az.summary(idata, var_names=["mu"]).reset_index(drop=True)
 blossom_data_post = blossom_data.copy().reset_index(drop=True)
 blossom_data_post["pred_mean"] = post_pred["mean"]
 blossom_data_post["pred_hdi_lower"] = post_pred["hdi_3%"]
@@ -325,7 +252,7 @@ blossom_data.plot.scatter(
     color="cornflowerblue",
     s=10,
     title="Cherry blossom data with posterior predictions",
-    ylabel="Day of Year",
+    ylabel="Days in bloom",
 )
 for knot in knot_list:
     plt.gca().axvline(knot, color="grey", alpha=0.4)
@@ -340,33 +267,20 @@ plt.fill_between(
 );
 ```

-## Authors
-
-- Authored by Joshua Cook in October, 2021
-- Updated by [Tyler James Burch](https://github.com/tjburch) in October, 2021
-
-+++
-
 ## References

 :::{bibliography}
 :filter: docname in docnames
 :::

-I would like to recognize the discussion ["Spline Regression in PyMC3"](https://discourse.pymc.io/t/spline-regression-in-pymc3/6235) on the PyMC3 Discourse as the inspiration of this example and for the helpful discussion and problem-solving that improved it further.
-
 +++

 ## Watermark

 ```{code-cell} ipython3
 %load_ext watermark
-%watermark -n -u -v -iv -w -p theano,xarray,patsy
+%watermark -n -u -v -iv -w -p aesara,xarray,patsy
 ```

 :::{include} ../page_footer.md
 :::
-
-```{code-cell} ipython3
-
-```
