(spline)=
:::{post} June 4, 2022
:tags: patsy, regression, spline
:category: beginner
:author: Joshua Cook
:::
+++ {"tags": []}
Often, the model we want to fit is not a perfect line between some predictor $x$ and outcome $y$.
The spline is effectively multiple individual lines, each fit to a different section of $x$.
Below is a full working example of how to fit a spline using PyMC. The data and model are taken from *Statistical Rethinking* 2e by Richard McElreath {cite:p}`mcelreath2018statistical`.

For more information on this method of non-linear modeling, I suggest beginning with chapter 5 of *Bayesian Modeling and Computation in Python* {cite:p}`martin2021bayesian`.
from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
from patsy import dmatrix
%matplotlib inline
%config InlineBackend.figure_format = "retina"
RANDOM_SEED = 8927
az.style.use("arviz-darkgrid")
The data for this example is the day of the year (`doy`, for "day of year") on which the cherry trees first bloomed in each year (`year`).
For convenience, years missing a `doy` were dropped (which, in general, is a bad way to deal with missing data!).
try:
    blossom_data = pd.read_csv(Path("..", "data", "cherry_blossoms.csv"), sep=";")
except FileNotFoundError:
    blossom_data = pd.read_csv(pm.get_data("cherry_blossoms.csv"), sep=";")
blossom_data.dropna().describe()
blossom_data = blossom_data.dropna(subset=["doy"]).reset_index(drop=True)
blossom_data.head(n=10)
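The difference between dropping rows with any missing value and dropping only the rows missing `doy` is easy to see on a toy frame. This is a minimal sketch: the `toy` data frame and its values are made up for illustration and are not taken from the cherry blossom data.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data (made-up values), loosely mimicking the columns used above.
toy = pd.DataFrame(
    {
        "year": [801, 802, 803, 804],
        "doy": [np.nan, 105.0, 98.0, np.nan],
        "temp": [6.1, np.nan, 6.4, 5.9],
    }
)

# Drops every row that has a NaN in *any* column.
any_missing_dropped = toy.dropna()

# Drops only the rows where `doy` is NaN, then renumbers the index from 0.
doy_missing_dropped = toy.dropna(subset=["doy"]).reset_index(drop=True)

print(len(any_missing_dropped))  # 1
print(len(doy_missing_dropped))  # 2
```

Restricting `dropna` to the outcome column keeps rows that are only missing other columns, which is why the model is fit to more rows than `dropna().describe()` would summarize.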
After dropping rows with missing data, 827 years remain, each with a recorded day of first bloom.
blossom_data.shape
If we visualize the data, it is clear that there is a lot of annual variation, but also some evidence of a non-linear trend in the bloom day over time.
blossom_data.plot.scatter(
    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
);
+++ {"tags": []}
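We will fit the following model.

$$
\begin{aligned}
D &\sim \mathcal{N}(\mu, \sigma) \\
\mu &= a + Bw \\
a &\sim \mathcal{N}(100, 5) \\
w &\sim \mathcal{N}(0, 3) \\
\sigma &\sim \text{Exp}(1)
\end{aligned}
$$

The day of bloom $D$ is modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. The mean is a linear model composed of an intercept $a$ and the spline, defined by the basis $B$ multiplied by the weights $w$, with one weight per column of the basis.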
The spline will have 15 knots, splitting the range of years into 16 sections (including the regions covering the years before and after those in which we have data). The knots are the boundaries between the sections of the spline, so named because the individual lines are tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each region has the same proportion of the data.
num_knots = 15
knot_list = np.quantile(blossom_data.year, np.linspace(0, 1, num_knots))
knot_list
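To make the "same proportion of data" idea concrete, here is a small sketch on synthetic data. The `years` array below is an assumption made up for this illustration (sparse early records, dense later ones); only the `np.quantile` call mirrors the code above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical, unevenly distributed "years": few early records, many later ones.
years = np.sort(
    np.concatenate([rng.uniform(800, 1400, size=100), rng.uniform(1400, 2000, size=400)])
)

num_knots = 5
knots = np.quantile(years, np.linspace(0, 1, num_knots))

# Count how many observations fall between each pair of consecutive knots.
counts, _ = np.histogram(years, bins=knots)

print(knots.round(1))
print(counts)
```

The knots end up close together where the data are dense and far apart where they are sparse, so each region between knots holds roughly the same number of observations.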
Below is a plot of the locations of the knots over the data.
blossom_data.plot.scatter(
    "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
)
for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4);
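We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression. The degree is set to 3, which gives a cubic b-spline.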
B = dmatrix(
    "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
    {"year": blossom_data.year.values, "knots": knot_list[1:-1]},
)
B
The b-spline basis is plotted below, showing the domain of each piece of the spline. The height of each curve indicates how influential the corresponding model covariate (one per spline region) will be on the model's inference of that region. The overlapping regions represent the knots, showing how the smooth transition from one region to the next is formed.
spline_df = (
    pd.DataFrame(B)
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)
color = plt.cm.magma(np.linspace(0, 0.80, len(spline_df.spline_i.unique())))
fig = plt.figure()
for i, c in enumerate(color):
    subset = spline_df.query(f"spline_i == {i}")
    subset.plot("year", "value", c=c, ax=plt.gca(), label=i)
plt.legend(title="Spline Index", loc="upper center", fontsize=8, ncol=6);
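To see where curves like these come from, here is a minimal NumPy sketch of the Cox-de Boor recursion for a clamped b-spline basis. The function `bspline_basis` is a hypothetical helper written for this illustration, and I am not claiming it reproduces how `patsy` computes `bs()` internally. It only demonstrates two properties of such a basis: at any point only a few curves are non-zero, and the curves sum to one.

```python
import numpy as np


def bspline_basis(x, interior_knots, lower, upper, degree=3):
    """Illustrative Cox-de Boor recursion (hypothetical helper, not patsy's code).

    Returns an array of shape (len(x), len(interior_knots) + degree + 1).
    """
    x = np.asarray(x, dtype=float)
    interior_knots = np.asarray(interior_knots, dtype=float)
    # Clamped knot vector: each boundary knot is repeated degree + 1 times.
    t = np.concatenate([[lower] * (degree + 1), interior_knots, [upper] * (degree + 1)])

    # Degree-0 basis: indicator of each half-open knot interval [t[i], t[i+1]).
    basis = np.zeros((len(x), len(t) - 1))
    for i in range(len(t) - 1):
        basis[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Close the last non-empty interval on the right so that x == upper is covered.
    basis[x == upper, len(t) - degree - 2] = 1.0

    # Raise the degree one step at a time, skipping zero-width intervals.
    for d in range(1, degree + 1):
        new_basis = np.zeros((len(x), len(t) - 1 - d))
        for i in range(len(t) - 1 - d):
            left_width = t[i + d] - t[i]
            right_width = t[i + d + 1] - t[i + 1]
            if left_width > 0:
                new_basis[:, i] += (x - t[i]) / left_width * basis[:, i]
            if right_width > 0:
                new_basis[:, i] += (t[i + d + 1] - x) / right_width * basis[:, i + 1]
        basis = new_basis
    return basis


x_demo = np.linspace(0, 10, 101)
B_demo = bspline_basis(x_demo, interior_knots=[2.5, 5.0, 7.5], lower=0, upper=10)

print(B_demo.shape)
print(np.allclose(B_demo.sum(axis=1), 1.0))
print((B_demo > 0).sum(axis=1).max())
```

With 3 interior knots and degree 3 this gives 3 + 3 + 1 = 7 basis functions; by the same count, the 13 interior knots used above give 17 columns in `B`.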
Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of `python-graphviz`, which I recommend doing in a `conda` virtual environment).
COORDS = {"splines": np.arange(B.shape[1])}
with pm.Model(coords=COORDS) as spline_model:
    a = pm.Normal("a", 100, 5)
    w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
    mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
    sigma = pm.Exponential("sigma", 1)
    D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
pm.model_to_graphviz(spline_model)
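The generative structure of this model can also be sketched in plain NumPy, which may help in reading the PyMC code. This is an illustration under stated assumptions: the basis `B_toy` is a stand-in built from triangular "hat" functions (not the cubic basis from `patsy`), and the sizes `n_obs`, `n_basis` and `n_draws` are arbitrary. Only the prior parameters are copied from the model above.

```python
import numpy as np

rng = np.random.default_rng(8927)

n_obs, n_basis, n_draws = 50, 6, 200  # arbitrary sizes for the illustration

# Stand-in basis (assumption): triangular hat functions on [0, 1], whose rows sum to 1.
x = np.linspace(0, 1, n_obs)
centers = np.linspace(0, 1, n_basis)
width = centers[1] - centers[0]
B_toy = np.clip(1 - np.abs(x[:, None] - centers[None, :]) / width, 0, None)

# Draw from the priors used in the model above.
a = rng.normal(100, 5, size=n_draws)             # intercept
w = rng.normal(0, 3, size=(n_draws, n_basis))    # one weight per basis function
sigma = rng.exponential(1.0, size=n_draws)       # Exponential with rate 1 has scale 1

# Linear predictor mu = a + B w, computed for every prior draw at once.
mu = a[:, None] + w @ B_toy.T                    # shape (n_draws, n_obs)

# Simulated observations around the mean.
D_sim = rng.normal(mu, sigma[:, None])

print(mu.shape, D_sim.shape)
```

Each row of `mu` is one curve implied by the priors, which is the same kind of quantity that `pm.sample_prior_predictive` explores for the real model.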
with spline_model:
    idata = pm.sample_prior_predictive()
    idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
Now we can analyze the draws from the posterior of the model.
Below is a table summarizing the posterior distributions of the model parameters $a$, $w$, and $\sigma$.
az.summary(idata, var_names=["a", "w", "sigma"])
The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed.
az.plot_trace(idata, var_names=["a", "w", "sigma"]);
az.plot_forest(idata, var_names=["w"], combined=False, r_hat=True);
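For intuition about what an $\widehat{R}$ value like the one shown in the forest plot measures, here is a sketch of the classic Gelman-Rubin statistic in NumPy. This is deliberately the simple, non-split formula and not the rank-normalized version that ArviZ reports; the function name `gelman_rubin_rhat` and the simulated chains are assumptions made for this illustration.

```python
import numpy as np


def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()              # mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)     # between-chain variance
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(pooled / within)


rng = np.random.default_rng(8927)

# Four simulated chains that all sample the same distribution (well mixed).
mixed = rng.normal(0, 1, size=(4, 1000))

# The same chains, but two of them are shifted, as if stuck in a different region.
stuck = mixed + np.array([0.0, 0.0, 3.0, 3.0])[:, None]

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print(rhat_mixed, rhat_stuck)
```

Values very close to 1 mean the chains agree with each other, while values well above 1 mean the between-chain variation is large relative to the within-chain variation.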
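Another visualization of the fit spline values is to plot them multiplied against the basis matrix.
The knot boundaries are shown as vertical lines again, but now each column of the spline basis is multiplied by the posterior mean of the corresponding value of $w$ (the rainbow-colored curves). The dot product of $B$ and $w$, which is the actual computation in the linear model, is shown in black.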
wp = idata.posterior["w"].mean(("chain", "draw")).values
spline_df = (
    pd.DataFrame(B * wp.T)
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)
spline_df_merged = (
    pd.DataFrame(np.dot(B, wp.T))
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)
color = plt.cm.rainbow(np.linspace(0, 1, len(spline_df.spline_i.unique())))
fig = plt.figure()
for i, c in enumerate(color):
    subset = spline_df.query(f"spline_i == {i}")
    subset.plot("year", "value", c=c, ax=plt.gca(), label=i)
spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)
for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4);
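The relationship between the colored curves and the black curve is just the definition of a matrix-vector product: summing the weighted columns gives the dot product. A tiny sketch, using a made-up matrix `B_small` and weights `w_small` rather than the fitted values:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical small basis matrix (5 observations, 3 basis functions) and weights.
B_small = rng.random((5, 3))
w_small = np.array([1.0, -2.0, 0.5])

# Element-wise product: column j of the basis scaled by weight j (one curve per column).
contributions = B_small * w_small

# Dot product: the spline's total contribution at each observation.
total = B_small @ w_small

print(np.allclose(contributions.sum(axis=1), total))
```

This is why the black curve in the plot above is the pointwise sum of the colored ones.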
Lastly, we can visualize the predictions of the model by plotting the posterior mean of `mu`, together with its 94% HDI, over the data.
post_pred = az.summary(idata, var_names=["mu"]).reset_index(drop=True)
blossom_data_post = blossom_data.copy().reset_index(drop=True)
blossom_data_post["pred_mean"] = post_pred["mean"]
blossom_data_post["pred_hdi_lower"] = post_pred["hdi_3%"]
blossom_data_post["pred_hdi_upper"] = post_pred["hdi_97%"]
blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Cherry blossom data with posterior predictions",
    ylabel="Day of Year",
)
for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4)
blossom_data_post.plot("year", "pred_mean", ax=plt.gca(), lw=3, color="firebrick")
plt.fill_between(
    blossom_data_post.year,
    blossom_data_post.pred_hdi_lower,
    blossom_data_post.pred_hdi_upper,
    color="firebrick",
    alpha=0.4,
);
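The `hdi_3%` and `hdi_97%` columns used for the shaded band are the bounds of a 94% highest density interval, that is, the narrowest interval containing 94% of the posterior draws. Here is a minimal sketch of that idea for a unimodal sample; `simple_hdi` is a hypothetical helper written for this illustration and the draws are simulated, so treat it as a sketch rather than ArviZ's implementation.

```python
import numpy as np


def simple_hdi(samples, prob=0.94):
    """Narrowest interval containing about `prob` of the samples (unimodal case)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = len(x)
    k = int(np.floor(prob * n))          # index offset spanned by each candidate interval
    widths = x[k:] - x[: n - k]          # width of every candidate interval
    i = int(np.argmin(widths))           # the narrowest one
    return x[i], x[i + k]


rng = np.random.default_rng(8927)
draws = rng.normal(0, 1, size=20_000)    # simulated stand-in for posterior draws

lower, upper = simple_hdi(draws)
coverage = np.mean((draws >= lower) & (draws <= upper))
print(lower, upper, coverage)
```

For a symmetric distribution like this one the HDI is close to the central interval, but for skewed posteriors the two can differ noticeably.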
:::{bibliography}
:filter: docname in docnames
:::
+++
- Created by Joshua Cook
- Updated by Tyler James Burch
- Updated by Chris Fonnesbeck
+++
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,xarray,patsy
:::{include} ../page_footer.md
:::