
update BART example #323


Merged · 4 commits · May 23, 2022

Changes from 2 commits

175 changes: 98 additions & 77 deletions examples/BART/BART_introduction.ipynb

Large diffs are not rendered by default.

40 changes: 17 additions & 23 deletions myst_nbs/BART/BART_introduction.myst.md
@@ -6,7 +6,7 @@ jupytext:
     format_version: 0.13
     jupytext_version: 1.13.7
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: Python 3.9.7 ('base')
   language: python
   name: python3
 ---
@@ -27,8 +27,9 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pymc as pm
+import pymc_experimental as pmx

-print(f"Running on PyMC3 v{pm.__version__}")
+print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
@@ -84,10 +85,10 @@ y_data = hist / 4
In PyMC a BART variable can be defined very much like other random variables. One important difference is that we have to pass our Xs and Ys to the BART variable. Here we also make explicit that we are going to use a sum over 20 trees (`m=20`). A low number of trees like 20 can be good enough for simple models like this one, and can also work very well as a quick approximation for more complex models, in particular during the iterative or exploratory phase of modeling. In those cases, once we are more certain about the model we actually like, we can improve the approximation by increasing `m`; in the literature it is common to find reports of good results with values like 50, 100, or 200.

```{code-cell} ipython3
-with pm.Model(rng_seeder=rng) as model_coal:
-    μ = pm.BART("μ", X=x_data, Y=y_data, m=20)
+with pm.Model() as model_coal:
+    μ = pmx.BART("μ", X=x_data, Y=y_data, m=20)
     y_pred = pm.Poisson("y_pred", mu=pm.math.exp(μ), observed=y_data)
-    idata_coal = pm.sample()
+    idata_coal = pm.sample(random_seed=RANDOM_SEED)
```
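
Once we are more certain about the model, we can refit with a larger number of trees, as described above. A minimal sketch of that refinement step, reusing the `x_data` and `y_data` from this example (`m=50` is just one of the commonly reported values, not part of this PR):

```{code-cell} ipython3
# A sketch: refit the settled model with more trees for a finer approximation.
with pm.Model() as model_coal_refined:
    μ = pmx.BART("μ", X=x_data, Y=y_data, m=50)  # more trees than the exploratory m=20
    y_pred = pm.Poisson("y_pred", mu=pm.math.exp(μ), observed=y_data)
    idata_coal_refined = pm.sample(random_seed=RANDOM_SEED)
```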

The white line in the following plot shows the median rate of accidents. The darker orange band represents the 50% HDI and the lighter one the 94% HDI. We can see a rapid decrease of coal accidents between 1880 and 1900. Feel free to compare these results with those in the original {ref}`pymc:pymc_overview` example.
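
The code that produces this figure lies outside the hunks shown in this diff. A minimal sketch of how such a plot can be built with ArviZ, assuming `μ` has one posterior value per entry of `x_data` (the name `rate` and the axis labels are ours, purely illustrative):

```{code-cell} ipython3
import arviz as az

# A sketch: 50% and 94% HDI bands with the posterior median rate on top.
rate = np.exp(idata_coal.posterior["μ"]).rename("rate")  # μ is modeled on the log scale
_, ax = plt.subplots()
for prob, alpha in [(0.94, 0.2), (0.50, 0.5)]:
    hdi = az.hdi(rate, hdi_prob=prob)["rate"]  # dims: (μ_dim, hdi={lower, higher})
    ax.fill_between(x_data[:, 0], hdi.sel(hdi="lower"), hdi.sel(hdi="higher"), color="C1", alpha=alpha)
ax.plot(x_data[:, 0], rate.median(dim=["chain", "draw"]), "w-")
ax.set_xlabel("years")
ax.set_ylabel("rate");
```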
@@ -111,15 +112,15 @@ In the previous plot the white line is the median over 4000 posterior draws, and
The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we just need to know that we are assuming that a stepwise function is a good enough approximation for our problem. In practice this is often the case because we sum over many trees, usually values like 50, 100, or 200. Additionally, we often average over the posterior distribution. All this makes the "steps" smoother, even though we never really have a smooth function as, for example, with Gaussian processes (splines). A nice theoretical result tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowhere-differentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.

```{code-cell} ipython3
-plt.step(x_data, np.exp(pm.bart.predict(idata_coal, rng, size=2).T));
+plt.step(x_data, np.exp(pmx.bart.predict(idata_coal, rng, x_data, size=2).T));
```
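
To make the claim that a sum of stepwise functions is itself a stepwise function concrete, here is a small self-contained illustration with synthetic step functions (not tied to the coal data):

```{code-cell} ipython3
# Three step functions and their sum: the jumps accumulate, and the result is still a step function.
grid = np.linspace(0, 1, 200)
steps = [np.where(grid > t, h, 0.0) for t, h in [(0.2, 1.0), (0.5, -0.7), (0.8, 0.4)]]
for s in steps:
    plt.step(grid, s, alpha=0.4)
plt.step(grid, np.sum(steps, axis=0), lw=2, label="sum")
plt.legend();
```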

To gain further intuition, the next figures show 3 of the `m` trees. As we can see, these are definitely not very good approximators by themselves. Inspecting individual trees is generally not necessary. We are just showing them here to build intuition about BART.

```{code-cell} ipython3
 bart_trees = idata_coal.sample_stats.bart_trees
 for i in [0, 1, 2]:
-    plt.step(x_data[:, 0], bart_trees[0, 0, i].item().predict_output())
+    plt.step(x_data[:, 0], [bart_trees[0, 0, i].item().predict(x) for x in x_data])
```

## Biking with BART
@@ -139,11 +140,11 @@ Y = bikes["count"]
```

```{code-cell} ipython3
-with pm.Model(rng_seeder=rng) as model_bikes:
+with pm.Model() as model_bikes:
     σ = pm.HalfNormal("σ", Y.std())
-    μ = pm.BART("μ", X, Y, m=50)
+    μ = pmx.BART("μ", X, Y, m=50)
     y = pm.Normal("y", μ, σ, observed=Y)
-    idata_bikes = pm.sample()
+    idata_bikes = pm.sample(random_seed=RANDOM_SEED)
```

### Partial dependence plots
@@ -153,8 +154,7 @@ with pm.Model(rng_seeder=rng) as model_bikes:
To help us interpret the results of our model we are going to use partial dependence plots. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, the effect that a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \neq i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC provides a utility function to make this plot from the inference data.

```{code-cell} ipython3
-pm.bart.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
-# plt.savefig("pdp_discrete.png", bbox_inches='tight')
+pmx.bart.plot_dependence(idata_bikes, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
```
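
The computation behind a partial dependence plot is simple enough to sketch by hand: for each grid value of the covariate of interest, replace that column in the data and average the predictions. In the sketch below the helper `partial_dependence` and the toy predictor are ours, purely for illustration; the real plot above uses the package's own `pmx.bart.plot_dependence`:

```{code-cell} ipython3
# A sketch of the partial dependence computation for covariate i:
# fix X[:, i] to each grid value, keep the other columns as observed,
# and average the model predictions over the data.
def partial_dependence(predict, X, i, grid):
    """`predict` maps an (n, p) array to an (n,) array of predictions (hypothetical helper)."""
    pdp = []
    for value in grid:
        X_mod = np.array(X, dtype=float, copy=True)
        X_mod[:, i] = value  # intervene on covariate i only
        pdp.append(predict(X_mod).mean())  # average over all other covariates
    return np.array(pdp)

# Usage with a toy predictor, just to show the mechanics:
toy_predict = lambda X_mod: X_mod[:, 0] ** 2 + X_mod[:, 1]
grid = np.linspace(0, 1, 20)
plt.plot(grid, partial_dependence(toy_predict, np.random.rand(100, 2), 0, grid));
```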

From this plot we can see the main effect of each covariate on the predicted value. This is very useful as we can recover complex relationships beyond monotonically increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks, around 8 h and 17 h, and a minimum at midnight.
@@ -171,22 +171,16 @@ As we saw in the previous section a partial dependence plot can visualize give us

The following plot shows the relative importance on a scale from 0 to 1 (less to more important); the individual importances sum to 1. Notice that, at least in this case, the relative importance qualitatively agrees with the partial dependence plot.

Additionally, we provide a novel method to assess variable importance. You can see an example in the bottom panel. On the x-axis we have the number of components (variables) and on the y-axis the Pearson correlation between the predictions of the full model (all variables included) and those of the restricted models, i.e. those with only a subset of the variables of the full model. The components are included following the relative variable importance order, as shown in the top panel. Thus, in this example 1 component means `hour`, two components mean `hour` and `temperature`, and 3 components mean `hour`, `temperature`, and `humidity`. Finally, four components mean `hour`, `temperature`, `humidity`, and `workingday`, i.e., the full model. Hence, from the next figure we can see that even a model with a single component, `hour`, is very close to the full model. Moreover, the model with two components, `hour` and `temperature`, is on average indistinguishable from the full model. The error bars represent the 94% HDI from the posterior predictive distribution. It is important to notice that to compute these correlations we do not refit the models; instead, the predictions of the restricted models are approximated from the full model.

```{code-cell} ipython3
-_, ax = plt.subplots(1)
-VI = (
-    idata_bikes.sample_stats["variable_inclusion"]
-    .stack(samples=("chain", "draw"))
-    .mean("samples")
-    .values
-)
-ax.plot(VI / VI.sum(), "o-")
-ax.set_xticks(range(4))
-ax.set_xticklabels(["hour", "temperature", "humidity", "workingday"])
-ax.set_ylabel("relative importance");
+labels = ["hour", "temperature", "humidity", "workingday"]
+pmx.bart.utils.plot_variable_importance(idata_bikes, X.values, labels, samples=100);
```
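
The correlation computation described above can be sketched generically. In the following sketch the restricted-model predictions are synthetic stand-ins (in the real workflow they are approximated from the full model, as noted above); only the Pearson-correlation step is the point:

```{code-cell} ipython3
# A sketch of the variable-importance check: Pearson correlation between
# full-model predictions and restricted-model predictions (synthetic stand-ins).
rng_vi = np.random.default_rng(0)
full_pred = rng_vi.normal(size=200)
for k in range(1, 5):
    # Stand-in for a restricted model with k components; noise shrinks as k grows.
    restricted_pred = full_pred + rng_vi.normal(scale=1.0 / k, size=200)
    r = np.corrcoef(full_pred, restricted_pred)[0, 1]
    print(f"{k} component(s): Pearson r = {r:.3f}")
```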

## Authors
* Authored by Osvaldo Martin in Dec, 2021 ([pymc-examples#259](https://github.com/pymc-devs/pymc-examples/pull/259))
* Updated by Osvaldo Martin in May, 2022 ([pymc-examples#323](https://github.com/pymc-devs/pymc-examples/pull/323))

+++
