Commit 4f93c44

new Interrupted Time Series notebook (#444)
* initial commit
* fix typo
* fix wording in final sentence
* add section on causal DAG
* fix typo + more info in modelling section
* improve how data is generated and dealt with in the model
* add comment about prior predictive
* add brief definition of causal impact
1 parent dce63a0 commit 4f93c44

4 files changed

+1506
-0
lines changed

examples/causal_inference/interrupted_time_series.ipynb

+1,129
Large diffs are not rendered by default.
@@ -0,0 +1,14 @@
import daft
import matplotlib.pyplot as plt

plt.rcParams.update({"text.usetex": True})

pgm = daft.PGM()
pgm.add_node("t", "time", 0, 0.75, aspect=1.5)
pgm.add_node("treat", "treatment", -0.75, 0, aspect=1.8)
pgm.add_node("y", "outcome", 0.75, 0, aspect=1.8)
pgm.add_edge("t", "y")
pgm.add_edge("t", "treat")
pgm.add_edge("treat", "y")
pgm.render()
pgm.savefig("DAG_interrupted_time_series.png", dpi=500)
@@ -0,0 +1,363 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.7
kernelspec:
  display_name: pymc_env
  language: python
  name: pymc_env
---

(interrupted_time_series)=
# Interrupted time series analysis

:::{post} Oct, 2022
:tags: counterfactuals, causal inference, time series, forecasting, causal impact, quasi experiments
:category: intermediate
:author: Benjamin T. Vincent
:::

+++

This notebook focuses on how to conduct a simple Bayesian [interrupted time series analysis](https://en.wikipedia.org/wiki/Interrupted_time_series). This is useful in [quasi-experimental settings](https://en.wikipedia.org/wiki/Quasi-experiment) where an intervention was applied to all treatment units.

For example, suppose a change was made to a website and you want to know its causal impact. _If_ the change was applied selectively and randomly to a test group of website users, then you may be able to make causal claims using the [A/B testing approach](https://en.wikipedia.org/wiki/A/B_testing).

However, if the website change was rolled out to _all_ users of the website then you do not have a control group. In this case you do not have a direct measurement of the counterfactual: what _would have happened if_ the website change had not been made. But if you have data over a 'good' number of time points, then you may be able to make use of the interrupted time series approach.

Interested readers are directed to the excellent textbook [The Effect](https://theeffectbook.net/) {cite:p}`huntington2021effect`. Chapter 17 covers 'event studies', a term which the author prefers to 'interrupted time series'.

+++

## Causal DAG

A simple causal DAG for the interrupted time series is given below, but see {cite:p}`huntington2021effect` for a more general DAG. In short it says:

* The outcome is causally influenced by time (e.g. other factors that change over time) and by the treatment.
* The treatment is causally influenced by time.

![](DAG_interrupted_time_series.png)

Intuitively, we could describe the logic of the approach as:
* We know that the outcome varies over time.
* If we build a model of how the outcome varies over time _before_ the treatment, then we can predict the counterfactual of what we would expect to happen _if_ the treatment had not occurred.
* We can compare this counterfactual with the observations from the time of the intervention onwards. If there is a meaningful discrepancy then we can attribute it to the causal impact of the intervention.

This is reasonable if we have ruled out other plausible causes occurring at the same point in time as (or after) the intervention. This becomes harder to justify as more time passes since the intervention, because it becomes more likely that other relevant events may have occurred that could provide alternative causal explanations.

If this does not make sense immediately, I recommend checking the example data figure below then revisiting this section.

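The three steps of this logic can be sketched in a few lines of plain NumPy. This is only an illustrative least-squares stand-in, not the Bayesian model used in this notebook; the toy numbers (a slope of 0.1 and a step of 2 from time point 60 onwards) and all variable names are assumptions made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy series: linear trend plus a step of +2 from time point 60 onwards
t = np.arange(100)
treated = t >= 60
y = 0.1 * t + 2.0 * treated + rng.normal(0, 0.5, size=t.size)

# 1. fit a linear trend using the pre-intervention data only
slope, intercept = np.polyfit(t[~treated], y[~treated], deg=1)

# 2. extrapolate that trend to forecast the counterfactual (no intervention)
counterfactual = intercept + slope * t[treated]

# 3. the estimated causal impact is the observed outcome minus the counterfactual
impact = y[treated] - counterfactual
mean_impact = impact.mean()
```

Here `mean_impact` should land near the simulated step size. Unlike this sketch, the Bayesian version yields a full distribution over the counterfactual, and hence over the impact, rather than a single point estimate.
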
```{code-cell} ipython3
import arviz as az
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr

from scipy.stats import norm
```

```{code-cell} ipython3
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
```

Now let's define some helper functions.

```{code-cell} ipython3
:tags: [hide-cell]

def format_x_axis(ax, minor=False):
    # major ticks
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y %b"))
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.grid(which="major", linestyle="-", axis="x")
    # minor ticks
    if minor:
        ax.xaxis.set_minor_formatter(mdates.DateFormatter("%Y %b"))
        ax.xaxis.set_minor_locator(mdates.MonthLocator())
        ax.grid(which="minor", linestyle=":", axis="x")
    # rotate labels
    for label in ax.get_xticklabels(which="both"):
        label.set(rotation=70, horizontalalignment="right")


def plot_xY(x, Y, ax):
    quantiles = Y.quantile((0.025, 0.25, 0.5, 0.75, 0.975), dim=("chain", "draw")).transpose()

    az.plot_hdi(
        x,
        hdi_data=quantiles.sel(quantile=[0.025, 0.975]),
        fill_kwargs={"alpha": 0.25},
        smooth=False,
        ax=ax,
    )
    az.plot_hdi(
        x,
        hdi_data=quantiles.sel(quantile=[0.25, 0.75]),
        fill_kwargs={"alpha": 0.5},
        smooth=False,
        ax=ax,
    )
    ax.plot(x, quantiles.sel(quantile=0.5), color="C1", lw=3)


# default figure sizes
figsize = (10, 5)
```

## Generate data

The focus of this example is on making causal claims using the interrupted time series approach. Therefore we will work with some relatively simple synthetic data which only requires a very simple model.

```{code-cell} ipython3
:tags: []

treatment_time = "2017-01-01"
β0 = 0
β1 = 0.1
dates = pd.date_range(
    start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
)
N = len(dates)


def causal_effect(df):
    return (df.index > treatment_time) * 2


df = (
    pd.DataFrame()
    .assign(time=np.arange(N), date=dates)
    .set_index("date", drop=True)
    .assign(y=lambda x: β0 + β1 * x.time + causal_effect(x) + norm(0, 0.5).rvs(N))
)
df
```

```{code-cell} ipython3
# Split into pre and post intervention dataframes
pre = df[df.index < treatment_time]
post = df[df.index >= treatment_time]
```

```{code-cell} ipython3
fig, ax = plt.subplots()
ax = pre["y"].plot(label="pre")
post["y"].plot(ax=ax, label="post")
ax.axvline(treatment_time, c="k", ls=":")
plt.legend();
```

In this simple dataset, we have a noisy linear trend upwards, and because this data is synthetic we know that we have a step increase in the outcome at the intervention time, and this effect is persistent over time.

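Written as an equation, the data-generating process in the code above is the following, where $t$ is the integer `time` index and the second argument of `norm(0, 0.5)` is read as a standard deviation:

$$
y_t = \beta_0 + \beta_1 t + 2 \cdot \mathbb{1}[\text{date}_t > \text{treatment time}] + \epsilon_t,
\qquad \epsilon_t \sim \mathrm{Normal}(0, 0.5)
$$

with $\beta_0 = 0$ and $\beta_1 = 0.1$. The indicator term is the step increase of size 2 that we hope to recover as the causal impact.
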
+++

## Modelling
Here we build a simple linear model. Remember that we are building a model of the pre-intervention data with the goal that it would do a reasonable job of forecasting what would have happened if the intervention had not been applied. Put another way, we are _not_ modelling any aspect of the post-intervention observations such as a change in intercept, slope or whether the effect is transient or permanent.

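In equation form, the model defined in the next code cell is the following, where $t$ ranges over the pre-intervention time points only and the second argument of each Normal distribution is a standard deviation:

$$
\begin{aligned}
\beta_0 &\sim \mathrm{Normal}(0, 1) \\
\beta_1 &\sim \mathrm{Normal}(0, 0.2) \\
\sigma &\sim \mathrm{HalfNormal}(2) \\
\mu_t &= \beta_0 + \beta_1 t \\
y_t &\sim \mathrm{Normal}(\mu_t, \sigma)
\end{aligned}
$$
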
```{code-cell} ipython3
with pm.Model() as model:
    # observed predictors and outcome
    time = pm.MutableData("time", pre["time"].to_numpy(), dims="obs_id")
    # priors
    beta0 = pm.Normal("beta0", 0, 1)
    beta1 = pm.Normal("beta1", 0, 0.2)
    # the actual linear model
    mu = pm.Deterministic("mu", beta0 + (beta1 * time), dims="obs_id")
    sigma = pm.HalfNormal("sigma", 2)
    # likelihood
    pm.Normal("obs", mu=mu, sigma=sigma, observed=pre["y"].to_numpy(), dims="obs_id")
```

```{code-cell} ipython3
pm.model_to_graphviz(model)
```

## Prior predictive check

As part of the Bayesian workflow, we will plot our prior predictions to see what outcomes the model considers plausible before having observed any data.

```{code-cell} ipython3
with model:
    idata = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

fig, ax = plt.subplots(figsize=figsize)

plot_xY(pre.index, idata.prior_predictive["obs"], ax)
format_x_axis(ax)
ax.plot(pre.index, pre["y"], label="observed")
ax.set(title="Prior predictive distribution in the pre intervention era")
plt.legend();
```

This seems reasonable in that the priors over the intercept and slope are broad enough to lead to predicted observations which easily contain the actual data. This suggests that the particular priors chosen will not unduly constrain the posterior parameter estimates.

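As a rough numerical companion to the plot, prior predictive draws can also be simulated by hand and compared against the data. The sketch below is a plain NumPy approximation, not what `pm.sample_prior_predictive` does internally; the priors mirror those in the model, while the stand-in data `y_pre`, the number of draws, and the coverage summary are assumptions made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws = 2000

# hypothetical stand-in for the pre-intervention data (mimics the synthetic trend)
t_pre = np.arange(84)
y_pre = 0.1 * t_pre + rng.normal(0, 0.5, size=t_pre.size)

# draw parameters from the priors used in the model
beta0 = rng.normal(0, 1, size=n_draws)
beta1 = rng.normal(0, 0.2, size=n_draws)
sigma = np.abs(rng.normal(0, 2, size=n_draws))  # HalfNormal(2)

# simulate observations for every prior draw: shape (n_draws, n_timepoints)
mu = beta0[:, None] + beta1[:, None] * t_pre[None, :]
y_sim = rng.normal(mu, sigma[:, None])

# fraction of data points inside the central 95% prior predictive interval
lower, upper = np.quantile(y_sim, [0.025, 0.975], axis=0)
coverage = np.mean((y_pre >= lower) & (y_pre <= upper))
```

A coverage close to 1 is the numerical counterpart of the visual impression that the prior predictive band easily contains the data.
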
+++

## Inference
Draw samples from the posterior distribution, remembering that we are doing this for the pre-intervention data only.

```{code-cell} ipython3
with model:
    idata.extend(pm.sample(random_seed=RANDOM_SEED))
```

```{code-cell} ipython3
az.plot_trace(idata, var_names=["~mu"]);
```

## Posterior predictive check

Another important aspect of the Bayesian workflow is to plot the model's posterior predictions, allowing us to see how well the model can retrodict the already observed data. It is at this point that we can decide whether the model is too simple (then we'd build more complexity into the model) or if it's fine.

```{code-cell} ipython3
with model:
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=RANDOM_SEED))

fig, ax = plt.subplots(figsize=figsize)

az.plot_hdi(pre.index, idata.posterior_predictive["obs"], hdi_prob=0.5, smooth=False)
az.plot_hdi(pre.index, idata.posterior_predictive["obs"], hdi_prob=0.95, smooth=False)
ax.plot(pre.index, pre["y"], label="observed")
format_x_axis(ax)
ax.set(title="Posterior predictive distribution in the pre intervention era")
plt.legend();
```

The next step is not strictly necessary, but we can calculate the difference between the model retrodictions and the data to look at the errors. This can be useful to identify any unexpected inability to retrodict pre-intervention data.

```{code-cell} ipython3
:tags: [hide-input]

# convert outcome into an XArray object with a labelled dimension to help in the next step
y = xr.DataArray(pre["y"].to_numpy(), dims=["obs_id"])

# do the calculation by taking the difference
excess = y - idata.posterior_predictive["obs"]
```

```{code-cell} ipython3
fig, ax = plt.subplots(figsize=figsize)
# the transpose is to keep arviz happy, ordering the dimensions as (chain, draw, obs_id)
az.plot_hdi(pre.index, excess.transpose(..., "obs_id"), hdi_prob=0.5, smooth=False)
az.plot_hdi(pre.index, excess.transpose(..., "obs_id"), hdi_prob=0.95, smooth=False)
format_x_axis(ax)
ax.axhline(y=0, color="k")
ax.set(title="Residuals, pre intervention");
```

## Counterfactual inference
Now we will use our model to predict the outcome in the 'what if?' scenario of no intervention.

So we update the model with the `time` data from the `post` intervention dataframe and run posterior predictive sampling to predict the observations we would expect in this counterfactual scenario. We could also call this 'forecasting'.

```{code-cell} ipython3
with model:
    pm.set_data(
        {
            "time": post["time"].to_numpy(),
        }
    )
    counterfactual = pm.sample_posterior_predictive(
        idata, var_names=["obs"], random_seed=RANDOM_SEED
    )
```

```{code-cell} ipython3
:tags: [hide-input]

fig, ax = plt.subplots(figsize=figsize)

plot_xY(post.index, counterfactual.posterior_predictive["obs"], ax)
format_x_axis(ax, minor=False)
ax.plot(post.index, post["y"], label="observed")
ax.set(
    title="Counterfactual: Posterior predictive forecast of outcome if intervention had not taken place"
)
plt.legend();
```

We now have the ingredients needed to calculate the causal impact. This is simply the difference between the observations and the Bayesian counterfactual predictions.

+++

## Causal impact: since the intervention

+++

Now we'll take the observed outcome and subtract the predicted outcome under the counterfactual scenario to arrive at our estimate of the causal impact.

```{code-cell} ipython3
# convert outcome into an XArray object with a labelled dimension to help in the next step
outcome = xr.DataArray(post["y"].to_numpy(), dims=["obs_id"])

# do the calculation by taking the difference
excess = outcome - counterfactual.posterior_predictive["obs"]
```

And we can easily compute the cumulative causal impact:

```{code-cell} ipython3
# calculate the cumulative causal impact
cumsum = excess.cumsum(dim="obs_id")
```

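To report numbers rather than plots, posterior draws of the excess can be reduced to point estimates and intervals. The sketch below uses a plain NumPy array of hypothetical draws standing in for `excess` (draws by post-intervention time points); the array contents, its shape, and the variable names are assumptions. With the real `xarray` objects the same reductions would be taken over the `chain` and `draw` dimensions.

```python
import numpy as np

rng = np.random.default_rng(2)

# hypothetical stand-in for posterior draws of the excess:
# 4000 draws x 36 post-intervention time points, centred on an impact of 2
excess_draws = rng.normal(2.0, 0.5, size=(4000, 36))

# cumulative impact per draw, accumulated along the time axis
cumulative_draws = excess_draws.cumsum(axis=1)

# posterior mean and central 95% interval of the total impact at the final time point
total = cumulative_draws[:, -1]
total_mean = total.mean()
total_lower, total_upper = np.quantile(total, [0.025, 0.975])
```

Summarising per draw first and only then taking quantiles keeps the uncertainty of the cumulative impact coherent, which would not be the case if interval bounds were summed across time points.
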
```{code-cell} ipython3
:tags: [hide-input]

fig, ax = plt.subplots(2, 1, figsize=(figsize[0], 9), sharex=True)

# Plot the excess
# The transpose is to keep arviz happy, ordering the dimensions as (chain, draw, obs_id)
plot_xY(post.index, excess.transpose(..., "obs_id"), ax[0])
format_x_axis(ax[0], minor=True)
ax[0].axhline(y=0, color="k")
ax[0].set(title="Causal impact, since intervention")

# Plot the cumulative excess
plot_xY(post.index, cumsum.transpose(..., "obs_id"), ax[1])
format_x_axis(ax[1], minor=False)
ax[1].axhline(y=0, color="k")
ax[1].set(title="Cumulative causal impact, since intervention");
```

And there we have it - we've done some Bayesian counterfactual inference in PyMC using the interrupted time series approach! In just a few steps we've:
- Built a simple model to predict a time series.
- Inferred the model parameters based on pre-intervention data, running prior and posterior predictive checks. These checks suggested that the model describes the pre-intervention data well.
- Used the model to create counterfactual predictions of what would happen after the intervention time if the intervention had not occurred.
- Calculated the causal impact (and cumulative causal impact) by comparing the observed outcome to our counterfactual expected outcome in the case of no intervention.

There are of course many ways that the interrupted time series approach could be more involved in real world settings. For example there could be more temporal structure, such as seasonality. If so then we might want to use a specific time series model, not just a linear regression model. There could also be additional informative predictor variables to incorporate into the model. Additionally, some designs do not just consist of pre and post intervention periods (also known as AB designs), but could also involve periods where the intervention is inactive, active, and then inactive again (also known as an ABA design).

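As one hedged illustration of the seasonality point, a linear regression can be given yearly structure by adding Fourier terms as extra predictors. This is a generic technique and not part of this notebook's model; the function name `fourier_features` and its parameters are hypothetical.

```python
import numpy as np


def fourier_features(t, period, n_order):
    """Return sin/cos pairs for harmonics 1..n_order, shape (len(t), 2 * n_order)."""
    t = np.asarray(t, dtype=float)
    k = np.arange(1, n_order + 1)
    angles = 2 * np.pi * np.outer(t, k) / period
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


# monthly time index with a yearly (12 month) period and two harmonics
t = np.arange(120)
X_seasonal = fourier_features(t, period=12, n_order=2)

# design matrix: intercept, linear trend, then the seasonal terms
X = np.column_stack([np.ones_like(t, dtype=float), t, X_seasonal])
```

Each column of `X` would receive its own coefficient, so the counterfactual forecast could then follow a trend plus a repeating yearly pattern instead of a straight line.
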
+++

## References

:::{bibliography}
:filter: docname in docnames
:::

+++

## Authors
- Authored by [Benjamin T. Vincent](https://github.com/drbenvincent) in October 2022.

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,aeppl,xarray
```

:::{include} ../page_footer.md
:::
