# bart.py (forked from pymc-devs/pymc)
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from pandas import DataFrame, Series

from pymc3.distributions.distribution import NoDistribution
from pymc3.distributions.tree import LeafNode, SplitNode, Tree

__all__ = ["BART"]


class BaseBART(NoDistribution):
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
        if self.X.ndim != 2:
            raise ValueError("The design matrix X must have two dimensions")
        if self.Y.ndim != 1:
            raise ValueError("The response vector Y must have one dimension")
        if self.X.shape[0] != self.Y.shape[0]:
            raise ValueError(
                "The design matrix X and the response vector Y must have the same number of rows"
            )
        if not isinstance(m, int):
            raise ValueError("The number of trees m must be an int")
        if m < 1:
            raise ValueError("The number of trees m must be greater than zero")
if alpha <= 0 or 1 <= alpha:
raise ValueError(
"The value for the alpha parameter for the tree structure "
"must be in the interval (0, 1)"
)
self.num_observations = X.shape[0]
self.num_variates = X.shape[1]
self.available_predictors = list(range(self.num_variates))
self.ssv = sample_splitting_variable(split_prior, self.num_variates)
self.m = m
self.alpha = alpha
self.trees = self.init_list_of_trees()
self.all_trees = []
self.mean = fast_mean()
self.prior_prob_leaf_node = compute_prior_probability(alpha)

    def preprocess_XY(self, X, Y):
        if isinstance(Y, (Series, DataFrame)):
            Y = Y.to_numpy()
        if isinstance(X, (Series, DataFrame)):
            X = X.to_numpy()
        missing_data = np.any(np.isnan(X))
        # Add a small amount of jitter to the covariates so candidate split values
        # are less likely to be tied.
        X = np.random.normal(X, np.std(X, 0) / 100)
        return X, Y, missing_data

    def init_list_of_trees(self):
initial_value_leaf_nodes = self.Y.mean() / self.m
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
list_of_trees = []
for i in range(self.m):
new_tree = Tree.init_tree(
tree_id=i,
leaf_node_value=initial_value_leaf_nodes,
idx_data_points=initial_idx_data_points_leaf_nodes,
)
list_of_trees.append(new_tree)
        # Diff trick to speed up the computation of residuals. From Section 3.1 of
        # Kapelner, A. and Bleich, J. (2013), bartMachine: A Powerful Tool for Machine
        # Learning in R. ArXiv e-prints.
        # sum_trees_output contains the sum of the predicted output over all trees.
        # When R_j is needed, we subtract the current predicted output of tree T_j.
        self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
return list_of_trees
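    # Illustrative sketch of the diff trick above (comments only; tree_j, old_pred
    # and new_pred are hypothetical names, not part of this module):
    #     old_pred = tree_j.predict_output(self.num_observations)
    #     R_j = self.Y - (self.sum_trees_output - old_pred)  # partial residuals for tree j
    #     ... update tree_j using R_j ...
    #     self.sum_trees_output += new_pred - old_pred       # keep the running sum in sync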

    def __iter__(self):
        return iter(self.trees)

    def _repr_latex_(self):
        raise NotImplementedError

    def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
        x_j = self.X[idx_data_points_split_node, idx_split_variable]
        if self.missing_data:
            x_j = x_j[~np.isnan(x_j)]
        values = np.unique(x_j)
        # The last value is never available as it would leave the right subtree empty.
        return values[:-1]

    def grow_tree(self, tree, index_leaf_node):
current_node = tree.get_node(index_leaf_node)
index_selected_predictor = self.ssv.rvs()
selected_predictor = self.available_predictors[index_selected_predictor]
available_splitting_rules = self.get_available_splitting_rules(
current_node.idx_data_points, selected_predictor
)
        # Growing can fail when there are no available splitting rules for the
        # selected predictor.
        if available_splitting_rules.size == 0:
            return False, None
index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=selected_splitting_rule,
)
left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
new_split_node, current_node.idx_data_points
)
left_node_value = self.draw_leaf_value(left_node_idx_data_points)
right_node_value = self.draw_leaf_value(right_node_idx_data_points)
new_left_node = LeafNode(
index=current_node.get_idx_left_child(),
value=left_node_value,
idx_data_points=left_node_idx_data_points,
)
new_right_node = LeafNode(
index=current_node.get_idx_right_child(),
value=right_node_value,
idx_data_points=right_node_idx_data_points,
)
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
return True, index_selected_predictor

    def get_new_idx_data_points(self, current_split_node, idx_data_points):
idx_split_variable = current_split_node.idx_split_variable
split_value = current_split_node.split_value
left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
left_node_idx_data_points = idx_data_points[left_idx]
right_node_idx_data_points = idx_data_points[~left_idx]
return left_node_idx_data_points, right_node_idx_data_points

    def get_residuals(self):
"""Compute the residuals."""
R_j = self.Y - self.sum_trees_output
return R_j

    def get_residuals_loo(self, tree):
        """Compute the residuals, leaving out the contribution of the passed tree."""
        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
        return R_j

    def draw_leaf_value(self, idx_data_points):
        """Draw the residual mean."""
        R_j = self.get_residuals()[idx_data_points]
        draw = self.mean(R_j)
        return draw

    def predict(self, X_new):
        """Compute out-of-sample predictions evaluated at X_new."""
        trees = self.all_trees
        num_observations = X_new.shape[0]
        pred = np.zeros((len(trees), num_observations))
        for draw, trees_to_sum in enumerate(trees):
            new_Y = np.zeros(num_observations)
            for tree in trees_to_sum:
                new_Y += np.array([tree.predict_out_of_sample(x) for x in X_new])
            pred[draw] = new_Y
        return pred
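    # Note: pred has shape (len(self.all_trees), len(X_new)); pred[d, i] is the
    # sum-of-trees prediction of stored posterior draw d evaluated at X_new[i].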


def compute_prior_probability(alpha):
    """
    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).

    Taken from equation 19 in [Rockova2018].

    Parameters
    ----------
    alpha : float

    Returns
    -------
    list with probabilities for leaf nodes

    References
    ----------
    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
       arXiv, `link <https://arxiv.org/abs/1810.00787>`__
    """
prior_leaf_prob = [0]
depth = 1
while prior_leaf_prob[-1] < 1:
prior_leaf_prob.append(1 - alpha ** depth)
depth += 1
return prior_leaf_prob
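# For example, with alpha = 0.25 the returned sequence starts
# [0, 0.75, 0.9375, 0.984375, ...], i.e. 1 - alpha ** depth, so deeper nodes are
# increasingly likely to be leaves.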


def fast_mean():
    """If available use Numba to speed up the computation of the mean."""
    try:
        from numba import jit
    except ImportError:
        return np.mean

    @jit
    def mean(a):
        count = a.shape[0]
        suma = 0
        for i in range(count):
            suma += a[i]
        return suma / count

    return mean
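# Usage sketch: fast_mean() returns a callable, e.g.
#     mean = fast_mean()
#     mean(np.array([1.0, 2.0, 3.0]))  # -> 2.0 (Numba-compiled if available)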


def discrete_uniform_sampler(upper_value):
    """Draw an integer from the discrete uniform distribution on [0, upper_value)."""
    return int(np.random.random() * upper_value)


class sample_splitting_variable:
    def __init__(self, prior, num_variates):
        self.prior = prior
        self.num_variates = num_variates

        if self.prior is not None:
            self.prior = np.asarray(self.prior)
            self.prior = self.prior / self.prior.sum()
            if self.prior.size != self.num_variates:
                raise ValueError(
                    f"The size of split_prior ({self.prior.size}) should be the "
                    f"same as the number of covariates ({self.num_variates})"
                )
            self.enu = list(enumerate(np.cumsum(self.prior)))

    def rvs(self):
        if self.prior is None:
            return int(np.random.random() * self.num_variates)
        else:
            r = np.random.random()
            for i, v in self.enu:
                if r <= v:
                    return i
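# Usage sketch (hypothetical prior): sample_splitting_variable([1, 2, 1], 3)
# normalizes the prior to [0.25, 0.5, 0.25], so rvs() returns index 0, 1 or 2
# with those probabilities; with prior=None every index is equally likely.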


class BART(BaseBART):
    """
    BART distribution.

    Distribution representing a sum over trees.

    Parameters
    ----------
    X : array-like
        The design matrix.
    Y : array-like
        The response vector.
    m : int
        Number of trees.
    alpha : float
        Controls the prior probability over the depth of the trees. Must be in the
        interval (0, 1), although it is recommended to be in the interval (0, 0.5].
    split_prior : array-like
        Each element of split_prior should be in the [0, 1] interval and the elements
        should sum to 1. Otherwise they will be normalized. Defaults to None, in which
        case all variables have the same prior probability.
    """

    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
        super().__init__(X, Y, m, alpha, split_prior)

    def _str_repr(self, name=None, dist=None, formatting="plain"):
        if dist is None:
            dist = self
        alpha = self.alpha
        m = self.m

        if "latex" in formatting:
            return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
        else:
            return f"{name} ~ BART(alpha = {alpha}, m = {m})"