-
-
Notifications
You must be signed in to change notification settings - Fork 269
Updated splines notebook to v4 #274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
931e48f
0bce126
27ac7f4
07e6c1a
abc5031
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,17 +6,17 @@ jupytext: | |
format_version: 0.13 | ||
jupytext_version: 1.13.7 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
display_name: Python 3.9.10 ('pymc-dev-py39') | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
# Splines in PyMC3 | ||
|
||
:::{post} Oct 8, 2021 | ||
:tags: patsy, pymc3.Deterministic, pymc3.Exponential, pymc3.Model, pymc3.Normal, regression, spline | ||
:::{post} May 6, 2022 | ||
:tags: patsy, pymc.Deterministic, pymc.Exponential, pymc.Model, pymc.Normal, regression, spline | ||
:category: beginner | ||
:author: Joshua Cook, Tyler James Burch | ||
:author: Joshua Cook updated by Tyler James Burch, Chris Fonnesbeck | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is metadata used to generate the citation advise at the bottom. It should only have names and exclude those who have only re-executed the notebooks without making changes. The Authors section should always be present and list everyone and the changes they did. We discussed that both in an issue, in the docs channel in slack and in a couple doc meetings. Like I often say, I do not have strong feelings about what was chosen and would have been perfectly happy (or even more happy) with another choice or with changing this now. What I do have strong feelings about is that all notebooks do the same because otherwise some people don't dare add their name at all whereas others do which only increases already present inequalities. Therefore changing that would mean changing the style guide and updating all the notebooks that are already in the Done column. |
||
::: | ||
|
||
+++ {"tags": []} | ||
|
@@ -27,22 +27,10 @@ Often, the model we want to fit is not a perfect line between some $x$ and $y$. | |
Instead, the parameters of the model are expected to vary over $x$. | ||
There are multiple ways to handle this situation, one of which is to fit a *spline*. | ||
The spline is effectively multiple individual lines, each fit to a different section of $x$, that are tied together at their boundaries, often called *knots*. | ||
Below is an exmaple of how to fit a spline using the Bayesian framework [PyMC3](https://docs.pymc.io). | ||
|
||
Below is a full working example of how to fit a spline using the probabilitic programming language PyMC3. | ||
The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`. | ||
As the book uses [Stan](https://mc-stan.org) (another advanced probabilitistic programming language), the modeling code is primarily taken from the [GitHub repository of the PyMC3 implementation of *Statistical Rethinking*](https://github.com/pymc-devs/resources/blob/master/Rethinking_2/Chp_04.ipynb). | ||
My contributions are primarily of explanation and additional analyses of the data and results. | ||
Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`. | ||
|
||
**Note that this is not a comprehensive review of splines – I primarily focus on the implementation in PyMC3.** | ||
For more information on this method of non-linear modeling, I suggesting beginning with chapter 7.4 "Regression Splines" of *An Introduction to Statistical Learning* {cite:p}`james2021statisticallearning`. | ||
|
||
+++ | ||
|
||
## Setup | ||
|
||
For this example, I employ the standard data science and Bayesian data analysis packages. | ||
In addition, the ['patsy'](https://patsy.readthedocs.io/en/latest/) library is used to generate the basis for the spline (more on that below). | ||
For more information on this method of non-linear modeling, I suggesting beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`. | ||
|
||
```{code-cell} ipython3 | ||
from pathlib import Path | ||
|
@@ -51,8 +39,7 @@ import arviz as az | |
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import pymc3 as pm | ||
import statsmodels.api as sm | ||
import pymc as pm | ||
|
||
from patsy import dmatrix | ||
``` | ||
|
@@ -62,14 +49,13 @@ from patsy import dmatrix | |
%config InlineBackend.figure_format = "retina" | ||
|
||
RANDOM_SEED = 8927 | ||
rng = np.random.default_rng(RANDOM_SEED) | ||
az.style.use("arviz-darkgrid") | ||
``` | ||
|
||
## Cherry blossom data | ||
|
||
The data for this example was the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`). | ||
Years missing a `doy` were dropped. | ||
The data for this example is the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`). | ||
For convenience, years missing a `doy` were dropped (which is a bad idea to deal with missing data in general!). | ||
|
||
```{code-cell} ipython3 | ||
try: | ||
|
@@ -92,11 +78,11 @@ After dropping rows with missing data, there are 827 years with the numbers of d | |
blossom_data.shape | ||
``` | ||
|
||
Below is a plot of the data we will be modeling showing the number of days of bloom per year. | ||
If we visualize the data, it is clear that there a lot of annual variation, but some evidence for a non-linear trend in bloom days over time. | ||
|
||
```{code-cell} ipython3 | ||
blossom_data.plot.scatter( | ||
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year" | ||
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom" | ||
); | ||
``` | ||
|
||
|
@@ -112,15 +98,11 @@ $\qquad a \sim \mathcal{N}(100, 10)$ | |
$\qquad w \sim \mathcal{N}(0, 10)$ | ||
$\quad \sigma \sim \text{Exp}(1)$ | ||
|
||
The number of days of bloom will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. | ||
The mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis. | ||
Both have relatively weak normal priors. | ||
The number of days of bloom $D$ will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. In turn, the mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis. Both have relatively weak normal priors. | ||
|
||
### Prepare the spline | ||
|
||
The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data). | ||
The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve. | ||
The knots will be unevenly spaced over the years such that each region will have the same proportion of data. | ||
The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data). The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each region will have the same proportion of data. | ||
|
||
```{code-cell} ipython3 | ||
num_knots = 15 | ||
|
@@ -138,50 +120,7 @@ for knot in knot_list: | |
plt.gca().axvline(knot, color="grey", alpha=0.4); | ||
``` | ||
|
||
Before doing any Bayesian modeling of the spline, we can get an idea of what our model should look like using the lowess modeling from `statsmodels` | ||
|
||
```{code-cell} ipython3 | ||
blossom_data.plot.scatter( | ||
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year" | ||
) | ||
for knot in knot_list: | ||
plt.gca().axvline(knot, color="grey", alpha=0.4) | ||
|
||
lowess = sm.nonparametric.lowess | ||
lowess_data = lowess(blossom_data.doy, blossom_data.year, frac=0.2, it=10) | ||
plt.plot(lowess_data[:, 0], lowess_data[:, 1], color="firebrick", lw=2); | ||
``` | ||
|
||
Another way of visualizing what the spline should look like is to plot individual linear models over the data between each knot. | ||
The spline will effectively be a compromise between these individual models and a continuous curve. | ||
|
||
```{code-cell} ipython3 | ||
blossom_data["knot_group"] = [np.where(a <= knot_list)[0][0] for a in blossom_data.year] | ||
blossom_data["knot_group"] = pd.Categorical(blossom_data["knot_group"], ordered=True) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
blossom_data.plot.scatter( | ||
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year" | ||
) | ||
for knot in knot_list: | ||
plt.gca().axvline(knot, color="grey", alpha=0.4) | ||
|
||
for i in np.arange(len(knot_list) - 1): | ||
|
||
# Subset data to knot | ||
x_range = (knot_list[i], knot_list[i + 1]) | ||
subset = blossom_data.query(f"year > {x_range[0]} & year <= {x_range[1]}") | ||
|
||
# Create a linear model and predict values | ||
lm = sm.OLS(subset.doy, sm.add_constant(subset.year, prepend=False)).fit() | ||
x_vals = np.linspace(x_range[0], x_range[1], 100) | ||
y_vals = lm.predict(sm.add_constant(x_vals, prepend=False)) | ||
# Add to plot | ||
plt.plot(x_vals, y_vals, color="firebrick", lw=2) | ||
``` | ||
|
||
Finally we can use 'patsy' to create the matrix $B$ that will be the b-spline basis for the regression. | ||
We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression. | ||
The degree is set to 3 to create a cubic b-spline. | ||
|
||
```{code-cell} ipython3 | ||
|
@@ -194,10 +133,7 @@ B = dmatrix( | |
B | ||
``` | ||
|
||
The b-spline basis is plotted below, showing the "domain" of each piece of the spline. | ||
The height of each curve indicates how "influential" the corresponding model covariate (one per spline region) will be on model's "inference" of that region. | ||
(The quotes are to indicate that these words were chosen to help with interpretation and are not the proper mathematical terms.) | ||
The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed. | ||
The b-spline basis is plotted below, showing the *domain* of each piece of the spline. The height of each curve indicates how influential the corresponding model covariate (one per spline region) will be on model's inference of that region. The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed. | ||
|
||
```{code-cell} ipython3 | ||
spline_df = ( | ||
|
@@ -217,17 +153,16 @@ plt.legend(title="Spline Index", loc="upper center", fontsize=8, ncol=6); | |
|
||
### Fit the model | ||
|
||
Finally, the model can be built using PyMC3. | ||
A graphical diagram shows the organization of the model parameters (note that this requires the installation of 'python-graphviz' which is easiest in a `conda` virtual environment). | ||
Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of `python-graphviz`, which I recommend doing in a `conda` virtual environment). | ||
|
||
```{code-cell} ipython3 | ||
COORDS = {"obs": np.arange(len(blossom_data.doy)), "splines": np.arange(B.shape[1])} | ||
with pm.Model(coords=COORDS) as spline_model: | ||
COORDS = {"splines": np.arange(B.shape[1])} | ||
with pm.Model(coords=COORDS, rng_seeder=RANDOM_SEED) as spline_model: | ||
a = pm.Normal("a", 100, 5) | ||
w = pm.Normal("w", mu=0, sd=3, dims="splines") | ||
w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines") | ||
mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T)) | ||
sigma = pm.Exponential("sigma", 1) | ||
D = pm.Normal("D", mu, sigma, observed=blossom_data.doy, dims="obs") | ||
D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs") | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
|
@@ -236,23 +171,16 @@ pm.model_to_graphviz(spline_model) | |
|
||
```{code-cell} ipython3 | ||
with spline_model: | ||
prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED) | ||
trace = pm.sample( | ||
draws=1000, | ||
tune=1000, | ||
random_seed=RANDOM_SEED, | ||
chains=4, | ||
return_inferencedata=True, | ||
) | ||
post_pred = pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED) | ||
trace.extend(az.from_pymc3(prior=prior_pred, posterior_predictive=post_pred)) | ||
idata = pm.sample_prior_predictive() | ||
idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4)) | ||
pm.sample_posterior_predictive(idata, extend_inferencedata=True) | ||
``` | ||
|
||
## Analysis | ||
|
||
Now we can analyze the draws from the posterior of the model. | ||
|
||
### Fit parameters | ||
### Parameter Estimates | ||
|
||
Below is a table summarizing the posterior distributions of the model parameters. | ||
The posteriors of $a$ and $\sigma$ are quite narrow while those for $w$ are wider. | ||
|
@@ -261,25 +189,24 @@ This is likely because all of the data points are used to estimate $a$ and $\sig | |
The effective sample size and $\widehat{R}$ values all look good, indiciating that the model has converged and sampled well from the posterior distribution. | ||
|
||
```{code-cell} ipython3 | ||
az.summary(trace, var_names=["a", "w", "sigma"]) | ||
az.summary(idata, var_names=["a", "w", "sigma"]) | ||
``` | ||
|
||
The trace plots of the model parameters look good (fuzzy caterpillars), further indicating that the chains converged and mixed. | ||
The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed. | ||
|
||
```{code-cell} ipython3 | ||
az.plot_trace(trace, var_names=["a", "w", "sigma"]); | ||
az.plot_trace(idata, var_names=["a", "w", "sigma"]); | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
az.plot_forest(trace, var_names=["w"], combined=False); | ||
az.plot_forest(idata, var_names=["w"], combined=False, r_hat=True); | ||
``` | ||
|
||
Another visualization of the fit spline values is to plot them multiplied against the basis matrix. | ||
The knot boundaries are shown in gray again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves). | ||
The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in blue. | ||
The knot boundaries are shown as vertical lines again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves). The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in black. | ||
|
||
```{code-cell} ipython3 | ||
wp = trace.posterior["w"].values.mean(axis=(0, 1)) | ||
wp = idata.posterior["w"].mean(("chain", "draw")).values | ||
|
||
spline_df = ( | ||
pd.DataFrame(B * wp.T) | ||
|
@@ -311,7 +238,7 @@ for knot in knot_list: | |
Lastly, we can visualize the predictions of the model using the posterior predictive check. | ||
|
||
```{code-cell} ipython3 | ||
post_pred = az.summary(trace, var_names=["mu"]).reset_index(drop=True) | ||
post_pred = az.summary(idata, var_names=["mu"]).reset_index(drop=True) | ||
blossom_data_post = blossom_data.copy().reset_index(drop=True) | ||
blossom_data_post["pred_mean"] = post_pred["mean"] | ||
blossom_data_post["pred_hdi_lower"] = post_pred["hdi_3%"] | ||
|
@@ -325,7 +252,7 @@ blossom_data.plot.scatter( | |
color="cornflowerblue", | ||
s=10, | ||
title="Cherry blossom data with posterior predictions", | ||
ylabel="Day of Year", | ||
ylabel="Days in bloom", | ||
) | ||
for knot in knot_list: | ||
plt.gca().axvline(knot, color="grey", alpha=0.4) | ||
|
@@ -340,33 +267,20 @@ plt.fill_between( | |
); | ||
``` | ||
|
||
## Authors | ||
|
||
- Authored by Joshua Cook in October, 2021 | ||
- Updated by [Tyler James Burch](https://github.com/tjburch) in October, 2021 | ||
|
||
+++ | ||
|
||
## References | ||
|
||
:::{bibliography} | ||
:filter: docname in docnames | ||
::: | ||
|
||
I would like to recognize the discussion ["Spline Regression in PyMC3"](https://discourse.pymc.io/t/spline-regression-in-pymc3/6235) on the PyMC3 Discourse as the inspiration of this example and for the helpful discussion and problem-solving that improved it further. | ||
|
||
+++ | ||
|
||
## Watermark | ||
|
||
```{code-cell} ipython3 | ||
%load_ext watermark | ||
%watermark -n -u -v -iv -w -p theano,xarray,patsy | ||
%watermark -n -u -v -iv -w -p aesara,xarray,patsy | ||
``` | ||
|
||
:::{include} ../page_footer.md | ||
::: | ||
|
||
```{code-cell} ipython3 | ||
|
||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should not be
pymc.<object>
tags anymore, that generates way too many tags that are not that useful so we decided to use https://docs.pymc.io/projects/examples/en/latest/object_index/index.html instead for this. The issue to track progress on this is #289