Skip to content

Updated splines notebook to v4 #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
951 changes: 415 additions & 536 deletions examples/splines/spline.ipynb

Large diffs are not rendered by default.

152 changes: 33 additions & 119 deletions myst_nbs/splines/spline.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ jupytext:
format_version: 0.13
jupytext_version: 1.13.7
kernelspec:
display_name: Python 3 (ipykernel)
display_name: Python 3.9.10 ('pymc-dev-py39')
language: python
name: python3
---

# Splines in PyMC3

:::{post} Oct 8, 2021
:tags: patsy, pymc3.Deterministic, pymc3.Exponential, pymc3.Model, pymc3.Normal, regression, spline
:::{post} May 6, 2022
:tags: patsy, pymc.Deterministic, pymc.Exponential, pymc.Model, pymc.Normal, regression, spline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should not be pymc.<object> tags anymore, that generates way too many tags that are not that useful so we decided to use https://docs.pymc.io/projects/examples/en/latest/object_index/index.html instead for this. The issue to track progress on this is #289

:category: beginner
:author: Joshua Cook, Tyler James Burch
:author: Joshua Cook updated by Tyler James Burch, Chris Fonnesbeck
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is metadata used to generate the citation advise at the bottom. It should only have names and exclude those who have only re-executed the notebooks without making changes. The Authors section should always be present and list everyone and the changes they did.

We discussed that both in an issue, in the docs channel in slack and in a couple doc meetings. Like I often say, I do not have strong feelings about what was chosen and would have been perfectly happy (or even more happy) with another choice or with changing this now. What I do have strong feelings about is that all notebooks do the same because otherwise some people don't dare add their name at all whereas others do which only increases already present inequalities. Therefore changing that would mean changing the style guide and updating all the notebooks that are already in the Done column.

:::

+++ {"tags": []}
Expand All @@ -27,22 +27,10 @@ Often, the model we want to fit is not a perfect line between some $x$ and $y$.
Instead, the parameters of the model are expected to vary over $x$.
There are multiple ways to handle this situation, one of which is to fit a *spline*.
The spline is effectively multiple individual lines, each fit to a different section of $x$, that are tied together at their boundaries, often called *knots*.
Below is an exmaple of how to fit a spline using the Bayesian framework [PyMC3](https://docs.pymc.io).

Below is a full working example of how to fit a spline using the probabilitic programming language PyMC3.
The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.
As the book uses [Stan](https://mc-stan.org) (another advanced probabilitistic programming language), the modeling code is primarily taken from the [GitHub repository of the PyMC3 implementation of *Statistical Rethinking*](https://github.com/pymc-devs/resources/blob/master/Rethinking_2/Chp_04.ipynb).
My contributions are primarily of explanation and additional analyses of the data and results.
Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath's](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.

**Note that this is not a comprehensive review of splines – I primarily focus on the implementation in PyMC3.**
For more information on this method of non-linear modeling, I suggesting beginning with chapter 7.4 "Regression Splines" of *An Introduction to Statistical Learning* {cite:p}`james2021statisticallearning`.

+++

## Setup

For this example, I employ the standard data science and Bayesian data analysis packages.
In addition, the ['patsy'](https://patsy.readthedocs.io/en/latest/) library is used to generate the basis for the spline (more on that below).
For more information on this method of non-linear modeling, I suggesting beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`.

```{code-cell} ipython3
from pathlib import Path
Expand All @@ -51,8 +39,7 @@ import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import statsmodels.api as sm
import pymc as pm

from patsy import dmatrix
```
Expand All @@ -62,14 +49,13 @@ from patsy import dmatrix
%config InlineBackend.figure_format = "retina"

RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
```

## Cherry blossom data

The data for this example was the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
Years missing a `doy` were dropped.
The data for this example is the number of days (`doy` for "days of year") that the cherry trees were in bloom in each year (`year`).
For convenience, years missing a `doy` were dropped (which is a bad idea to deal with missing data in general!).

```{code-cell} ipython3
try:
Expand All @@ -92,11 +78,11 @@ After dropping rows with missing data, there are 827 years with the numbers of d
blossom_data.shape
```

Below is a plot of the data we will be modeling showing the number of days of bloom per year.
If we visualize the data, it is clear that there a lot of annual variation, but some evidence for a non-linear trend in bloom days over time.

```{code-cell} ipython3
blossom_data.plot.scatter(
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
);
```

Expand All @@ -112,15 +98,11 @@ $\qquad a \sim \mathcal{N}(100, 10)$
$\qquad w \sim \mathcal{N}(0, 10)$
$\quad \sigma \sim \text{Exp}(1)$

The number of days of bloom will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$.
The mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis.
Both have relatively weak normal priors.
The number of days of bloom $D$ will be modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. In turn, the mean will be a linear model composed of a y-intercept $a$ and spline defined by the basis $B$ multiplied by the model parameter $w$ with a variable for each region of the basis. Both have relatively weak normal priors.

### Prepare the spline

The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data).
The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve.
The knots will be unevenly spaced over the years such that each region will have the same proportion of data.
The spline will have 15 *knots*, splitting the year into 16 sections (including the regions covering the years before and after those in which we have data). The knots are the boundaries of the spline, the name owing to how the individual lines will be tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each region will have the same proportion of data.

```{code-cell} ipython3
num_knots = 15
Expand All @@ -138,50 +120,7 @@ for knot in knot_list:
plt.gca().axvline(knot, color="grey", alpha=0.4);
```

Before doing any Bayesian modeling of the spline, we can get an idea of what our model should look like using the lowess modeling from `statsmodels`

```{code-cell} ipython3
blossom_data.plot.scatter(
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
)
for knot in knot_list:
plt.gca().axvline(knot, color="grey", alpha=0.4)

lowess = sm.nonparametric.lowess
lowess_data = lowess(blossom_data.doy, blossom_data.year, frac=0.2, it=10)
plt.plot(lowess_data[:, 0], lowess_data[:, 1], color="firebrick", lw=2);
```

Another way of visualizing what the spline should look like is to plot individual linear models over the data between each knot.
The spline will effectively be a compromise between these individual models and a continuous curve.

```{code-cell} ipython3
blossom_data["knot_group"] = [np.where(a <= knot_list)[0][0] for a in blossom_data.year]
blossom_data["knot_group"] = pd.Categorical(blossom_data["knot_group"], ordered=True)
```

```{code-cell} ipython3
blossom_data.plot.scatter(
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
)
for knot in knot_list:
plt.gca().axvline(knot, color="grey", alpha=0.4)

for i in np.arange(len(knot_list) - 1):

# Subset data to knot
x_range = (knot_list[i], knot_list[i + 1])
subset = blossom_data.query(f"year > {x_range[0]} & year <= {x_range[1]}")

# Create a linear model and predict values
lm = sm.OLS(subset.doy, sm.add_constant(subset.year, prepend=False)).fit()
x_vals = np.linspace(x_range[0], x_range[1], 100)
y_vals = lm.predict(sm.add_constant(x_vals, prepend=False))
# Add to plot
plt.plot(x_vals, y_vals, color="firebrick", lw=2)
```

Finally we can use 'patsy' to create the matrix $B$ that will be the b-spline basis for the regression.
We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression.
The degree is set to 3 to create a cubic b-spline.

```{code-cell} ipython3
Expand All @@ -194,10 +133,7 @@ B = dmatrix(
B
```

The b-spline basis is plotted below, showing the "domain" of each piece of the spline.
The height of each curve indicates how "influential" the corresponding model covariate (one per spline region) will be on model's "inference" of that region.
(The quotes are to indicate that these words were chosen to help with interpretation and are not the proper mathematical terms.)
The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.
The b-spline basis is plotted below, showing the *domain* of each piece of the spline. The height of each curve indicates how influential the corresponding model covariate (one per spline region) will be on model's inference of that region. The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.

```{code-cell} ipython3
spline_df = (
Expand All @@ -217,17 +153,16 @@ plt.legend(title="Spline Index", loc="upper center", fontsize=8, ncol=6);

### Fit the model

Finally, the model can be built using PyMC3.
A graphical diagram shows the organization of the model parameters (note that this requires the installation of 'python-graphviz' which is easiest in a `conda` virtual environment).
Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of `python-graphviz`, which I recommend doing in a `conda` virtual environment).

```{code-cell} ipython3
COORDS = {"obs": np.arange(len(blossom_data.doy)), "splines": np.arange(B.shape[1])}
with pm.Model(coords=COORDS) as spline_model:
COORDS = {"splines": np.arange(B.shape[1])}
with pm.Model(coords=COORDS, rng_seeder=RANDOM_SEED) as spline_model:
a = pm.Normal("a", 100, 5)
w = pm.Normal("w", mu=0, sd=3, dims="splines")
w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
sigma = pm.Exponential("sigma", 1)
D = pm.Normal("D", mu, sigma, observed=blossom_data.doy, dims="obs")
D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
```

```{code-cell} ipython3
Expand All @@ -236,23 +171,16 @@ pm.model_to_graphviz(spline_model)

```{code-cell} ipython3
with spline_model:
prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
trace = pm.sample(
draws=1000,
tune=1000,
random_seed=RANDOM_SEED,
chains=4,
return_inferencedata=True,
)
post_pred = pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED)
trace.extend(az.from_pymc3(prior=prior_pred, posterior_predictive=post_pred))
idata = pm.sample_prior_predictive()
idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```

## Analysis

Now we can analyze the draws from the posterior of the model.

### Fit parameters
### Parameter Estimates

Below is a table summarizing the posterior distributions of the model parameters.
The posteriors of $a$ and $\sigma$ are quite narrow while those for $w$ are wider.
Expand All @@ -261,25 +189,24 @@ This is likely because all of the data points are used to estimate $a$ and $\sig
The effective sample size and $\widehat{R}$ values all look good, indiciating that the model has converged and sampled well from the posterior distribution.

```{code-cell} ipython3
az.summary(trace, var_names=["a", "w", "sigma"])
az.summary(idata, var_names=["a", "w", "sigma"])
```

The trace plots of the model parameters look good (fuzzy caterpillars), further indicating that the chains converged and mixed.
The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed.

```{code-cell} ipython3
az.plot_trace(trace, var_names=["a", "w", "sigma"]);
az.plot_trace(idata, var_names=["a", "w", "sigma"]);
```

```{code-cell} ipython3
az.plot_forest(trace, var_names=["w"], combined=False);
az.plot_forest(idata, var_names=["w"], combined=False, r_hat=True);
```

Another visualization of the fit spline values is to plot them multiplied against the basis matrix.
The knot boundaries are shown in gray again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves).
The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in blue.
The knot boundaries are shown as vertical lines again, but now the spline basis is multipled against the values of $w$ (represented as the rainbow-colored curves). The dot product of $B$ and $w$ – the actual computation in the linear model – is shown in black.

```{code-cell} ipython3
wp = trace.posterior["w"].values.mean(axis=(0, 1))
wp = idata.posterior["w"].mean(("chain", "draw")).values

spline_df = (
pd.DataFrame(B * wp.T)
Expand Down Expand Up @@ -311,7 +238,7 @@ for knot in knot_list:
Lastly, we can visualize the predictions of the model using the posterior predictive check.

```{code-cell} ipython3
post_pred = az.summary(trace, var_names=["mu"]).reset_index(drop=True)
post_pred = az.summary(idata, var_names=["mu"]).reset_index(drop=True)
blossom_data_post = blossom_data.copy().reset_index(drop=True)
blossom_data_post["pred_mean"] = post_pred["mean"]
blossom_data_post["pred_hdi_lower"] = post_pred["hdi_3%"]
Expand All @@ -325,7 +252,7 @@ blossom_data.plot.scatter(
color="cornflowerblue",
s=10,
title="Cherry blossom data with posterior predictions",
ylabel="Day of Year",
ylabel="Days in bloom",
)
for knot in knot_list:
plt.gca().axvline(knot, color="grey", alpha=0.4)
Expand All @@ -340,33 +267,20 @@ plt.fill_between(
);
```

## Authors

- Authored by Joshua Cook in October, 2021
- Updated by [Tyler James Burch](https://github.com/tjburch) in October, 2021

+++

## References

:::{bibliography}
:filter: docname in docnames
:::

I would like to recognize the discussion ["Spline Regression in PyMC3"](https://discourse.pymc.io/t/spline-regression-in-pymc3/6235) on the PyMC3 Discourse as the inspiration of this example and for the helpful discussion and problem-solving that improved it further.

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p theano,xarray,patsy
%watermark -n -u -v -iv -w -p aesara,xarray,patsy
```

:::{include} ../page_footer.md
:::

```{code-cell} ipython3

```