Commit e51b9d3
Add Bayesian Additive Regression Trees (BARTs) (#4183)

* update from master
* black
* minor fix
* clean code
* blackify
* fix error residuals
* use a low number of max_stages for the first iteration, remove unnecessary errors
* use Rockova prior, refactor prior leaf prob computation
* clean code, add docstring
* reduce code
* speed up by fitting a subset of trees per step
* choose max
* improve docstrings
* refactor and clean code
* clean docstrings
* add tests and minor fixes.
  Co-authored-by: aloctavodia <[email protected]>
  Co-authored-by: jmloyola <[email protected]>
* remove space.
  Co-authored-by: aloctavodia <[email protected]>
  Co-authored-by: jmloyola <[email protected]>
* add variable importance report
* use ValueError
* wip return mean and std variable importance
* update variable importance report
* update release notes, remove vi hdi report
* test variable importance
* fix test
  Co-authored-by: jmloyola <[email protected]>

1 parent f732a01 commit e51b9d3

File tree

10 files changed: +750 -6 lines changed

RELEASE-NOTES.md

Lines changed: 3 additions & 1 deletion

@@ -16,11 +16,13 @@
 - `sample_posterior_predictive_w` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#4042](https://github.com/pymc-devs/pymc3/pull/4042))
 - Added `pymc3.gp.cov.Circular` kernel for Gaussian Processes on circular domains, e.g. the unit circle (see [#4082](https://github.com/pymc-devs/pymc3/pull/4082)).
 - Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926))
-- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/3926))
+- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/4115))
 - Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
+- Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
 - Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).


+
 ## PyMC3 3.9.3 (11 August 2020)

 ### Maintenance

pymc3/distributions/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -99,8 +99,11 @@
 from .timeseries import MvGaussianRandomWalk
 from .timeseries import MvStudentTRandomWalk

+from .bart import BART
+
 from .bound import Bound

+
 __all__ = [
     "Uniform",
     "Flat",
@@ -177,4 +180,5 @@
     "Moyal",
     "Simulator",
     "fast_sample_posterior_predictive",
+    "BART",
 ]

pymc3/distributions/bart.py

Lines changed: 252 additions & 0 deletions

@@ -0,0 +1,252 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .distribution import NoDistribution
from .tree import Tree, SplitNode, LeafNode

__all__ = ["BART"]


class BaseBART(NoDistribution):
    def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
        self.X = X
        self.Y = Y
        super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)

        if self.X.ndim != 2:
            raise ValueError("The design matrix X must have two dimensions")
        if self.Y.ndim != 1:
            raise ValueError("The response vector Y must have one dimension")
        if self.X.shape[0] != self.Y.shape[0]:
            raise ValueError(
                "The design matrix X and the response vector Y must have the same number of elements"
            )
        if not isinstance(m, int):
            raise ValueError("The number of trees m must be of type int")
        if m < 1:
            raise ValueError("The number of trees m must be greater than zero")
        if alpha <= 0 or 1 <= alpha:
            raise ValueError(
                "The value for the alpha parameter for the tree structure "
                "must be in the interval (0, 1)"
            )

        self.num_observations = X.shape[0]
        self.num_variates = X.shape[1]
        self.m = m
        self.alpha = alpha
        self.trees = self.init_list_of_trees()
        self.mean = fast_mean()
        self.prior_prob_leaf_node = compute_prior_probability(alpha)

    def init_list_of_trees(self):
        initial_value_leaf_nodes = self.Y.mean() / self.m
        initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
        list_of_trees = []
        for i in range(self.m):
            new_tree = Tree.init_tree(
                tree_id=i,
                leaf_node_value=initial_value_leaf_nodes,
                idx_data_points=initial_idx_data_points_leaf_nodes,
            )
            list_of_trees.append(new_tree)
        # Diff trick to speed up the computation of residuals. From Section 3.1 of
        # Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine
        # Learning in R. ArXiv e-prints, 2013.
        # sum_trees_output contains the sum of the predicted output over all trees.
        # When R_j is needed we subtract the current predicted output of tree T_j.
        self.sum_trees_output = np.full_like(self.Y, self.Y.mean())

        return list_of_trees
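
    # Illustrative sketch (toy values, not part of this diff): with two trees
    #   Y = [1.0, 2.0, 3.0]
    #   pred_t0 = [0.2, 0.2, 0.2]   # output of tree T_0
    #   pred_t1 = [0.5, 0.8, 1.1]   # output of tree T_1
    #   sum_trees_output = pred_t0 + pred_t1
    # the partial residual for T_1 is recovered without re-summing all trees:
    #   R_1 = Y - (sum_trees_output - pred_t1)   # equals Y - pred_t0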

    def __iter__(self):
        return iter(self.trees)

    def _repr_latex_(self):
        raise NotImplementedError

    def get_available_predictors(self, idx_data_points_split_node):
        possible_splitting_variables = []
        for j in range(self.num_variates):
            x_j = self.X[idx_data_points_split_node, j]
            x_j = x_j[~np.isnan(x_j)]
            for i in range(1, len(x_j)):
                if x_j[i - 1] != x_j[i]:
                    possible_splitting_variables.append(j)
                    break
        return possible_splitting_variables

    def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
        x_j = self.X[idx_data_points_split_node, idx_split_variable]
        x_j = x_j[~np.isnan(x_j)]
        values, indices = np.unique(x_j, return_index=True)
        # The last value is not considered because choosing it as the splitting
        # value would leave the right subtree empty.
        return values[:-1], indices[:-1]
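
    # Illustrative sketch (toy values, not part of this diff):
    #   x_j = [1.0, 1.0, 2.0, 3.0]  ->  np.unique gives values [1.0, 2.0, 3.0]
    # Splitting on 3.0 with the rule "x <= 3.0" would send every point to the
    # left child, so only [1.0, 2.0] are returned as candidate split values.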

    def grow_tree(self, tree, index_leaf_node):
        # This can be unsuccessful when there are no available predictors
        current_node = tree.get_node(index_leaf_node)

        available_predictors = self.get_available_predictors(current_node.idx_data_points)

        if not available_predictors:
            return False, None

        index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
        selected_predictor = available_predictors[index_selected_predictor]
        available_splitting_rules, _ = self.get_available_splitting_rules(
            current_node.idx_data_points, selected_predictor
        )
        index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
        selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
        new_split_node = SplitNode(
            index=index_leaf_node,
            idx_split_variable=selected_predictor,
            split_value=selected_splitting_rule,
        )

        left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
            new_split_node, current_node.idx_data_points
        )

        left_node_value = self.draw_leaf_value(left_node_idx_data_points)
        right_node_value = self.draw_leaf_value(right_node_idx_data_points)

        new_left_node = LeafNode(
            index=current_node.get_idx_left_child(),
            value=left_node_value,
            idx_data_points=left_node_idx_data_points,
        )
        new_right_node = LeafNode(
            index=current_node.get_idx_right_child(),
            value=right_node_value,
            idx_data_points=right_node_idx_data_points,
        )
        tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

        return True, index_selected_predictor

    def get_new_idx_data_points(self, current_split_node, idx_data_points):
        idx_split_variable = current_split_node.idx_split_variable
        split_value = current_split_node.split_value

        left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
        left_node_idx_data_points = idx_data_points[left_idx]
        right_node_idx_data_points = idx_data_points[~left_idx]

        return left_node_idx_data_points, right_node_idx_data_points

    def get_residuals(self):
        """Compute the residuals."""
        R_j = self.Y - self.sum_trees_output
        return R_j

    def get_residuals_loo(self, tree):
        """Compute the residuals, leaving the passed tree out of the sum of trees."""
        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
        return R_j

    def draw_leaf_value(self, idx_data_points):
        """Draw the residual mean."""
        R_j = self.get_residuals()[idx_data_points]
        draw = self.mean(R_j)
        return draw


def compute_prior_probability(alpha):
    """
    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
    Taken from equation 19 in [Rockova2018].

    Parameters
    ----------
    alpha : float

    Returns
    -------
    list with probabilities for leaf nodes

    References
    ----------
    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
       arXiv, `link <https://arxiv.org/abs/1810.00787>`__
    """
    prior_leaf_prob = [0]
    depth = 1
    while prior_leaf_prob[-1] < 1:
        prior_leaf_prob.append(1 - alpha ** depth)
        depth += 1
    return prior_leaf_prob
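
# Illustrative values (not part of this diff): compute_prior_probability(0.5)
# starts [0, 0.5, 0.75, 0.875, ...]; each entry is 1 - alpha**depth, so deeper
# nodes are increasingly likely to be leaves. The loop stops once alpha**depth
# is small enough that 1 - alpha**depth rounds to exactly 1.0 in floating point.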


def fast_mean():
    """If available, use Numba to speed up the computation of the mean."""
    try:
        from numba import jit
    except ImportError:
        return np.mean

    @jit
    def mean(a):
        count = a.shape[0]
        suma = 0
        for i in range(count):
            suma += a[i]
        return suma / count

    return mean
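
# Illustrative check (not part of this diff): fast_mean returns a callable
# whether or not Numba is installed, so callers never branch on the import:
#   mean = fast_mean()
#   mean(np.array([1.0, 2.0, 3.0]))  # -> 2.0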


def discrete_uniform_sampler(upper_value):
    """Draw an integer from the uniform distribution with bounds [0, upper_value)."""
    return int(np.random.random() * upper_value)


class BART(BaseBART):
    """
    BART distribution.

    Distribution representing a sum over trees.

    Parameters
    ----------
    X :
        The design matrix.
    Y :
        The response vector.
    m : int
        Number of trees.
    alpha : float
        Controls the prior probability over the depth of the trees. Must be in the
        interval (0, 1), although it is recommended to be in the interval (0, 0.5].
    """

    def __init__(self, X, Y, m=200, alpha=0.25):
        super().__init__(X, Y, m, alpha)

    def _str_repr(self, name=None, dist=None, formatting="plain"):
        if dist is None:
            dist = self
        alpha = self.alpha
        m = self.m

        if formatting == "latex":
            return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
        else:
            return f"{name} ~ BART(alpha = {alpha}, m = {m})"
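
A minimal end-to-end sketch of how the new distribution is meant to be used (hypothetical toy data; model and variable names are illustrative, and it assumes the BART-specific step method added elsewhere in this commit is assigned automatically by pm.sample):

    import numpy as np
    import pymc3 as pm

    # Toy data: one informative predictor plus two noise predictors.
    rng = np.random.RandomState(0)
    X = rng.uniform(size=(100, 3))
    Y = 2 * X[:, 0] + rng.normal(scale=0.1, size=100)

    with pm.Model():
        sigma = pm.HalfNormal("sigma", 1.0)
        mu = pm.BART("mu", X, Y, m=50, alpha=0.25)
        y = pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
        trace = pm.sample()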
