
Commit 5e92706

aloctavodia and alporter08 authored and committed
1 parent dff27a5 commit 5e92706

7 files changed: +730 −217 lines

examples/bart/bart_heteroscedasticity.ipynb

+130 −28
Large diffs are not rendered by default.

examples/bart/bart_heteroscedasticity.myst.md

+15 −8

@@ -1,5 +1,6 @@
 ---
 jupytext:
+  formats: ipynb,md
   text_representation:
     extension: .md
     format_name: myst
@@ -81,7 +82,7 @@ Next, we specify the model. Note that we just need one BART distribution which c
 
 ```{code-cell} ipython3
 with pm.Model() as model_marketing_full:
-    w = pmb.BART("w", X=X, Y=np.log(Y), m=200, shape=(2, n_obs))
+    w = pmb.BART("w", X=X, Y=np.log(Y), m=100, shape=(2, n_obs))
     y = pm.Gamma("y", mu=pm.math.exp(w[0]), sigma=pm.math.exp(w[1]), observed=Y)
 
 pm.model_to_graphviz(model=model_marketing_full)
@@ -91,7 +92,7 @@ We now fit the model.
 
 ```{code-cell} ipython3
 with model_marketing_full:
-    idata_marketing_full = pm.sample(random_seed=rng)
+    idata_marketing_full = pm.sample(2000, random_seed=rng, compute_convergence_checks=False)
     posterior_predictive_marketing_full = pm.sample_posterior_predictive(
         trace=idata_marketing_full, random_seed=rng
     )
@@ -104,7 +105,7 @@ We can now visualize the posterior predictive distribution of the mean and the l
 ```{code-cell} ipython3
 posterior_mean = idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[0]
 
-w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"])
+w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"], hdi_prob=0.5)
 
 pps = az.extract(
     posterior_predictive_marketing_full, group="posterior_predictive", var_names=["y"]
@@ -116,14 +117,19 @@ idx = np.argsort(X[:, 0])
 
 
 fig, ax = plt.subplots()
-az.plot_hdi(x=X[:, 0], y=pps, ax=ax, fill_kwargs={"alpha": 0.3, "label": r"Likelihood $94\%$ HDI"})
+az.plot_hdi(
+    x=X[:, 0],
+    y=pps,
+    ax=ax,
+    hdi_prob=0.90,
+    fill_kwargs={"alpha": 0.3, "label": r"Observations $90\%$ HDI"},
+)
 az.plot_hdi(
     x=X[:, 0],
     hdi_data=np.exp(w_hdi["w"].sel(w_dim_0=0)),
     ax=ax,
-    fill_kwargs={"alpha": 0.6, "label": r"Mean $94\%$ HDI"},
+    fill_kwargs={"alpha": 0.6, "label": r"Mean $50\%$ HDI"},
 )
-ax.plot(X[:, 0][idx], np.exp(posterior_mean[idx]), c="black", lw=3, label="Posterior Mean")
 ax.plot(df["youtube"], df["sales"], "o", c="C0", label="Raw Data")
 ax.legend(loc="upper left")
 ax.set(
@@ -138,8 +144,9 @@ The fit looks good! In fact, we see that the mean and variance increase as a fun
 +++
 
 ## Authors
-- Authored by [Juan Orduz](https://juanitorduz.github.io/) in February 2023
-- Rerun by Osvaldo Martin in March 2023
+- Authored by [Juan Orduz](https://juanitorduz.github.io/) in Feb, 2023
+- Rerun by Osvaldo Martin in Mar, 2023
+- Rerun by Osvaldo Martin in Nov, 2023
 
 +++

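For readers skimming this diff, here is a minimal, self-contained sketch of the heteroscedastic model that the hunks above modify. The synthetic `X`, `Y`, and `n_obs` below are made-up stand-ins for the notebook's marketing data; only the `pmb.BART` call with `shape=(2, n_obs)`, the Gamma likelihood, and the `pm.sample` arguments mirror lines shown in the diff.

```python
# Minimal sketch of the heteroscedastic BART model touched by the hunks above.
# X, Y, and n_obs are synthetic stand-ins for the notebook's marketing data.
import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(42)
n_obs = 200
X = np.sort(rng.uniform(0, 300, size=(n_obs, 1)), axis=0)
# Positive response whose mean and spread both grow with the predictor
Y = rng.gamma(shape=10.0, scale=(1.0 + 0.05 * X[:, 0]) / 10.0)

with pm.Model() as model_marketing_full:
    # One BART RV with shape=(2, n_obs): row 0 drives log(mu), row 1 drives log(sigma)
    w = pmb.BART("w", X=X, Y=np.log(Y), m=100, shape=(2, n_obs))
    y = pm.Gamma("y", mu=pm.math.exp(w[0]), sigma=pm.math.exp(w[1]), observed=Y)
    idata = pm.sample(2000, random_seed=rng, compute_convergence_checks=False)
```

Row 0 of `w` models the log of the Gamma mean and row 1 the log of its standard deviation, which is what lets both the mean and the spread of `y` vary with the predictor.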
examples/bart/bart_introduction.ipynb

+381 −129
Large diffs are not rendered by default.

examples/bart/bart_introduction.myst.md

+27 −22

@@ -1,11 +1,12 @@
 ---
 jupytext:
+  formats: ipynb,md
   text_representation:
     extension: .md
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: pymc-examples-env
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
@@ -115,32 +116,25 @@ ax.set_xlabel("years")
 ax.set_ylabel("rate");
 ```
 
-The white line in the following plot shows the median rate of accidents. The darker orange band represent the HDI 50% and the lighter one the 94%. We can see a rapid decrease of coal accidents between 1880 and 1900. Feel free to compare these results with those in the original {ref}`pymc:pymc_overview` example.
+The white line in the following plot shows the median rate of accidents. The darker orange band represents the HDI 50% and the lighter one the 94%. We can see a rapid decrease in coal accidents between 1880 and 1900. Feel free to compare these results with those in the original {ref}`pymc:pymc_overview` example.
 
-In the previous plot the white line is the mean over 4000 posterior draws, and each one of those posterior draws is a sum over `m=20` trees.
+In the previous plot, the white line is the mean over 4000 posterior draws, and each one of those posterior draws is a sum over `m=20` trees.
 
 
-The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we just need to know that we are assuming that a stepwise function is a good enough approximation for our problem. In practice this is often the case because we sum over many trees, usually values like 50, 100 or 200. Additionally, we often average over the posterior distribution. All this makes the "steps smoother", even when we never really have an smooth function as for example with Gaussian processes (splines). A nice theoretical result, tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowheredifferentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.
+The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we need to know that we are assuming that a stepwise function is a good enough approximation for our problem. In practice, this is often the case because we sum over many trees, usually values like 50, 100, or 200.
+Additionally, we often average over the posterior distribution. All this makes the "steps smoother", even when we never really have a smooth function as, for example, with Gaussian processes or splines. A nice theoretical result tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowhere-differentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.
 
 The following figure shows two samples of $\mu$ from the posterior.
 
 ```{code-cell} ipython3
 plt.step(x_data, rates.sel(chain=0, draw=[3, 10]).T);
 ```
 
-The next figure shows 3 trees. As we can see these are very simple function and definitely not very good approximators by themselves. Inspecting individuals trees is generally not necessary when working with BART, we are showing them just so we can gain further intuition on the inner workings of BART.
-
-```{code-cell} ipython3
-bart_trees = μ_.owner.op.all_trees
-for i in [0, 1, 2]:
-    plt.step(x_data[:, 0], [bart_trees[0][i].predict(x) for x in x_data])
-```
-
 ## Biking with BART
 
 +++
 
-To explore other features offered by BART in PyMC. We are now going to move on to a different example. In this example we have data about the number of bikes rental in a city, and we have chosen four covariates; the hour of the day, the temperature, the humidity and whether is a workingday or a weekend. This dataset is a subset of the [bike_sharing_dataset](http://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset).
+To explore other features offered by PyMC-BART, we are now going to move on to a different example. In this example, we have data about the number of bike rentals in a city, and we have chosen four covariates: the `hour` of the day, the `temperature`, the `humidity`, and whether it is a `workingday` or a weekend. This dataset is a subset of the [bike_sharing_dataset](http://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset).
 
 ```{code-cell} ipython3
 try:
@@ -189,32 +183,42 @@ We instead consider checking the convergence of BART variables an important part
 
 +++
 
-To help us interpret the results of our model we are going to use partial dependence plots. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, what is the effect that a covariate $X_i$ has of $Y$ while we average over all the other covariates ($X_j, \forall j \not = i$). This type of plot are not exclusive of BART. But they are often used in the BART literature. PyMC-BART provides an utility function to make this plot from the inference data.
+To help us interpret the results of our model we are going to use partial dependence plots. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, the effect that a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \not = i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC-BART provides a utility function to make this plot from a BART random variable.
 
 ```{code-cell} ipython3
-pmb.plot_dependence(μ, X=X, Y=Y, grid=(2, 2), func=np.exp);
+pmb.plot_pdp(μ, X=X, Y=Y, grid=(2, 2), func=np.exp, var_discrete=[3]);
 ```
 
-From this plot we can see the main effect of each covariate on the predicted value. This is very useful as we can recover complex relationship beyond monotonic increasing or decreasing effects. For example for the `hour` covariate we can see two peaks around 8 and and 17 hs and a minimum at midnight.
+From this plot, we can see the main effect of each covariate on the predicted value. This is very useful as we can recover complex relationships beyond monotonic increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks, around 8 and 17 hs, and a minimum at midnight.
 
-When interpreting partial dependence plots we should be careful about the assumptions in this plot. First we are assuming variables are independent. For example when computing the effect of `hour` we have to marginalize the effect of `temperature` and this means that to compute the partial dependence value at `hour=0` we are including all observed values of temperature, and this may include temperatures that are actually not observed at midnight, given that lower temperatures are more likely than higher ones. We are seeing only averages, so if for a covariate half the values are positively associated with predicted variable and the other half negatively associated. The partial dependence plot will be flat as their contributions will cancel each other out. This is a problem that can be solved by using individual conditional expectation plots `pmb.plot_dependence(..., kind="ice")`. Notice that all this assumptions are assumptions of the partial dependence plot, not of our model! In fact BART can easily accommodate interaction of variables Although the prior in BART regularizes high order interactions). For more on interpreting Machine Learning model you could check the "Interpretable Machine Learning" book {cite:p}`molnar2019`.
+When interpreting partial dependence plots we should be careful about the assumptions behind this plot. First, we assume that the variables are independent. For example, when computing the effect of `hour` we have to marginalize the effect of `temperature`, and this means that to compute the partial dependence value at `hour=0` we include all observed values of temperature, which may include temperatures that are not actually observed at midnight, given that lower temperatures are more likely than higher ones. Second, we are seeing only averages, so if for a covariate half the values are positively associated with the predicted variable and the other half negatively associated, the partial dependence plot will be flat, as their contributions will cancel each other out. This is a problem that can be solved by using individual conditional expectation plots, `pmb.plot_ice(...)`. Notice that all these assumptions are assumptions of the partial dependence plot, not of our model! In fact, BART can easily accommodate interactions of variables (although the prior in BART regularizes high-order interactions). For more on interpreting Machine Learning models, you could check the "Interpretable Machine Learning" book {cite:p}`molnar2019`.
 
-Finally like with other regression methods we should be careful that the effects we are seeing on individual variables are conditional on the inclusion of the other variables. So for example, while `humidity` seems to be mostly flat, meaning that this covariate has an small effect of the number of used bikes. This could be the case because `humidity` and `temperature` are correlated to some extend and once we include `temperature` in our model `humidity` does not provide too much extra information. Try for example fitting the model again but this time with `humidity` as the single covariate and then fitting the model again with `hour` as a single covariate. You should see that the result for this single-variate models will very similar to the previous figure for the `hour` covariate, but less similar for the `humidity` covariate.
+Finally, like with other regression methods, we should be careful that the effects we are seeing on individual variables are conditional on the inclusion of the other variables. So, for example, while `humidity` seems to be mostly flat, meaning that this covariate has a small effect on the number of used bikes, this could be because `humidity` and `temperature` are correlated to some extent, and once we include `temperature` in our model `humidity` does not provide much extra information. Try, for example, fitting the model again, but this time with `humidity` as the single covariate, and then fitting the model again with `hour` as the single covariate. You should see that the results for these single-covariate models will be very similar to the previous figure for the `hour` covariate, but less similar for the `humidity` covariate.
 
 +++
 
 ### Variable importance
 
-As we saw in the previous section a partial dependence plot can visualize give us an idea of how much each covariable contributes to the predicted outcome. But BART itself leads to a simple heuristic to estimate variable importance. That is simple count how many times a variable is included in all the regression trees. The intuition is that if a variable is important they it should appears more often in the fitted trees that less important variables. While this heuristic seems to provide reasonable results in practice, there is not too much theory justifying this procedure, at least not yet.
+As we saw in the previous section, a partial dependence plot can give us an idea of how much each covariate contributes to the predicted outcome. Moreover, PyMC-BART provides a novel method to assess the importance of each variable in the model. You can see an example in the following figure.
 
-The following plot shows the relative importance in a scale from 0 to 1 (less to more importance) and the sum of the individual importance is 1. See that, at least in this case, the relative importance qualitative agrees with the partial dependence plot.
+On the x-axis we have the number of covariates and on the y-axis R² (the square of the Pearson correlation coefficient) between the predictions made by the full model (all variables included) and the restricted models, those with only a subset of the variables.
 
-Additionally, PyMC-BART provides a novel method to assess the variable importance. You can see an example in the bottom panel. On the x-axis we have the number of covariables and on the y-axis the square of the Pearson correlation coefficient between the predictions made for the full-model (all variables included) and the restricted-models, those with only a subset of the variables. The components are included following the relative variable importance order, as show in the top panel. Thus, in this example "number of covariables" is 1 `hour`, 2 `hour` and `temperature`, 3 `hour`, `temperature`and `humidity`. Finally, 4 means `hour`, `temperature`, `humidity`, `workingday`, i.e., the full model. Hence, from the next figure we can see that even a model with a single component, `hour`, is very close to the full model. Even more, the model with two components `hour`, and `temperature` is on average indistinguishable from the full model. The error bars represent the 94 \% HDI from the posterior predictive distribution. It is important to notice that to compute these correlations we do not resample the models, instead the predictions of the restricted-models are approximated by *prunning* variables from the full-model.
+In this example, the most important variable is `hour`, then `temperature`, `humidity`, and finally `workingday`. Notice that the first value of R² is for a model that only includes the variable `hour`, the second R² is for a model with two variables, `hour` and `temperature`, and so on. Besides this ranking, we can see that even a model with a single component, `hour`, is very close to the full model. Even more, the model with the two components `hour` and `temperature` is on average indistinguishable from the full model. The error bars represent the 94% HDI from the posterior predictive distribution. This means that we should expect a model with only `hour` and `temperature` to have a predictive performance similar to that of the model with all four variables, `hour`, `temperature`, `humidity`, and `workingday`.
 
 ```{code-cell} ipython3
-pmb.plot_variable_importance(idata_bikes, μ, X, samples=100);
+pmb.plot_variable_importance(idata_bikes, μ, X);
 ```
 
+`plot_variable_importance` is fast because it makes two assumptions:
+
+* The ranking of the variables is computed with a simple heuristic: we just count how many times a variable is included in all the regression trees. The intuition is that if a variable is important it should appear more often in the fitted trees than less important variables.
+
+* The predictions used for the computation of R² come from the already fitted trees. For instance, to estimate the effect of a BART model with only the variable `hour`, we *prune* the branches that do not include this variable. This makes computations much faster, as we do not need to find a new set of trees.
+
+Instead of using the "counting heuristic", `plot_variable_importance` can also perform a backward search, `pmb.plot_variable_importance(..., method="backward")`. Internally, this will compute the R² for the full model, then for all models with one variable fewer than the full model, then for all models with two variables fewer, and so on. At each stage, we discard the variable that gives the lowest R². The backward method will be slower, as we need to compute predictions for more models.
+
++++
+
 ### Out-of-Sample Predictions
 
 In this section we want to show how to do out-of-sample predictions with BART. We are going to use the same dataset as before, but this time we are going to split the data into a training and a test set. We are going to use the training set to fit the model and the test set to evaluate the model.
@@ -400,6 +404,7 @@ This plot helps us understand the reason behind the bad performance on the test
 * Updated by Osvaldo Martin in Nov, 2022
 * Juan Orduz added out-of-sample section in Jan, 2023
 * Updated by Osvaldo Martin in Mar, 2023
+* Updated by Osvaldo Martin in Nov, 2023
 
 +++

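The variable-importance text added above hinges on a single quantity: R², the squared Pearson correlation between the predictions of the full model and those of a restricted model. The toy computation below illustrates only that idea; it is not PyMC-BART's implementation, and the prediction arrays are fabricated (PyMC-BART obtains restricted-model predictions by pruning the already fitted trees).

```python
# Toy illustration of the R² used in the variable-importance text above.
# The "predictions" are fabricated; PyMC-BART gets them by pruning fitted trees.
import numpy as np

def r_squared(full_pred, restricted_pred):
    """Square of the Pearson correlation between two prediction vectors."""
    r = np.corrcoef(full_pred, restricted_pred)[0, 1]
    return r**2

rng = np.random.default_rng(0)
full = rng.normal(size=500)                          # stand-in for full-model predictions
only_hour = full + rng.normal(scale=0.1, size=500)   # restricted model close to the full one
only_humidity = rng.normal(size=500)                 # restricted model mostly unrelated to it

print(f"hour only:     R2 = {r_squared(full, only_hour):.2f}")      # close to 1
print(f"humidity only: R2 = {r_squared(full, only_humidity):.2f}")  # close to 0
```

An R² near 1 means the restricted subset already reproduces the full model's predictions, which is the reading given above for the model that keeps only `hour` and `temperature`.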
examples/bart/bart_quantile_regression.ipynb

+173 −28
Large diffs are not rendered by default.

examples/bart/bart_quantile_regression.myst.md

+3 −1

@@ -1,5 +1,6 @@
 ---
 jupytext:
+  formats: ipynb,md
   text_representation:
     extension: .md
     format_name: myst
@@ -146,7 +147,8 @@ We can see that when we use a Normal likelihood, and from that fit we compute th
 
 ## Authors
 * Authored by Osvaldo Martin in Jan, 2023
-* Rerun by Osvaldo Martin in March 2023
+* Rerun by Osvaldo Martin in Mar, 2023
+* Rerun by Osvaldo Martin in Nov, 2023
 
 +++

sphinxext/thumbnail_extractor.py

+1 −1

@@ -112,7 +112,7 @@
     "time_series": "Time Series",
     "spatial": "Spatial Analysis",
     "diagnostics_and_criticism": "Diagnostics and Model Criticism",
-    "bart": "Bayesian Additive Regressive Trees",
+    "bart": "Bayesian Additive Regression Trees",
     "mixture_models": "Mixture Models",
     "survival_analysis": "Survival Analysis",
     "ode_models": "ODE models",

0 commit comments