
Commit a3c2060

Bart: Refactor splitting variables and predictions (#4310)
* refactor split variables, add prior split variables, add predict function
* remove dirichlet
* remove unused import
* use to_numpy to convert a series/dataframe to an array.
* use already defined variable
1 parent 7ef2de4 commit a3c2060

3 files changed (+108, -40 lines)

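For context before the diffs: the commit adds a split_prior argument to BART and a predict method for out-of-sample prediction. A rough usage sketch follows; the model setup, the data, and the Normal likelihood are illustrative, and reaching the BART instance through the variable's .distribution attribute is an assumption, not something this commit documents.

    import numpy as np
    import pymc3 as pm

    X = np.random.normal(size=(100, 3))
    Y = np.random.normal(size=100)

    with pm.Model():
        # Give the second covariate twice the prior weight when choosing split variables
        mu = pm.BART("mu", X, Y, m=50, split_prior=[0.25, 0.5, 0.25])
        y = pm.Normal("y", mu=mu, sigma=1, observed=Y)
        trace = pm.sample()

    # Out-of-sample predictions are averaged over the tree ensembles stored after tuning
    X_new = np.random.normal(size=(10, 3))
    pred = mu.distribution.predict(X_new)  # assumed access path; shape (stored draws, 10)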
pymc3/distributions/bart.py (+76, -32)

@@ -14,16 +14,19 @@
 
 import numpy as np
 
+from pandas import DataFrame, Series
+
 from pymc3.distributions.distribution import NoDistribution
 from pymc3.distributions.tree import LeafNode, SplitNode, Tree
 
 __all__ = ["BART"]
 
 
 class BaseBART(NoDistribution):
-    def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
-        self.X = X
-        self.Y = Y
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
+
+        self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
+
         super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
 
         if self.X.ndim != 2:
@@ -48,12 +51,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
 
         self.num_observations = X.shape[0]
         self.num_variates = X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+        self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
         self.m = m
         self.alpha = alpha
         self.trees = self.init_list_of_trees()
+        self.all_trees = []
         self.mean = fast_mean()
         self.prior_prob_leaf_node = compute_prior_probability(alpha)
 
+    def preprocess_XY(self, X, Y):
+        if isinstance(Y, (Series, DataFrame)):
+            Y = Y.to_numpy()
+        if isinstance(X, (Series, DataFrame)):
+            X = X.to_numpy()
+        missing_data = np.any(np.isnan(X))
+        X = np.random.normal(X, np.std(X, 0) / 100)
+        return X, Y, missing_data
+
     def init_list_of_trees(self):
         initial_value_leaf_nodes = self.Y.mean() / self.m
         initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
@@ -79,39 +94,26 @@ def __iter__(self):
     def __repr_latex(self):
         raise NotImplementedError
 
-    def get_available_predictors(self, idx_data_points_split_node):
-        possible_splitting_variables = []
-        for j in range(self.num_variates):
-            x_j = self.X[idx_data_points_split_node, j]
-            x_j = x_j[~np.isnan(x_j)]
-            for i in range(1, len(x_j)):
-                if x_j[i - 1] != x_j[i]:
-                    possible_splitting_variables.append(j)
-                    break
-        return possible_splitting_variables
-
     def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
         x_j = self.X[idx_data_points_split_node, idx_split_variable]
-        x_j = x_j[~np.isnan(x_j)]
-        values, indices = np.unique(x_j, return_index=True)
-        # The last value is not consider since if we choose it as the value of
-        # the splitting rule assignment, it would leave the right subtree empty.
-        return values[:-1], indices[:-1]
+        if self.missing_data:
+            x_j = x_j[~np.isnan(x_j)]
+        values = np.unique(x_j)
+        # The last value is never available as it would leave the right subtree empty.
+        return values[:-1]
 
     def grow_tree(self, tree, index_leaf_node):
-        # This can be unsuccessful when there are not available predictors
         current_node = tree.get_node(index_leaf_node)
 
-        available_predictors = self.get_available_predictors(current_node.idx_data_points)
-
-        if not available_predictors:
-            return False, None
-
-        index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
-        selected_predictor = available_predictors[index_selected_predictor]
-        available_splitting_rules, _ = self.get_available_splitting_rules(
+        index_selected_predictor = self.ssv.rvs()
+        selected_predictor = self.available_predictors[index_selected_predictor]
+        available_splitting_rules = self.get_available_splitting_rules(
             current_node.idx_data_points, selected_predictor
         )
+        # This can be unsuccessful when there are not available splitting rules
+        if available_splitting_rules.size == 0:
+            return False, None
+
         index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
         selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
         new_split_node = SplitNode(
@@ -167,6 +169,19 @@ def draw_leaf_value(self, idx_data_points):
         draw = self.mean(R_j)
         return draw
 
+    def predict(self, X_new):
+        """Compute out of sample predictions evaluated at X_new"""
+        trees = self.all_trees
+        num_observations = X_new.shape[0]
+        pred = np.zeros((len(trees), num_observations))
+        np.random.randint(len(trees))
+        for draw, trees_to_sum in enumerate(trees):
+            new_Y = np.zeros(num_observations)
+            for tree in trees_to_sum:
+                new_Y += [tree.predict_out_of_sample(x) for x in X_new]
+            pred[draw] = new_Y
+        return pred
+
 
 def compute_prior_probability(alpha):
     """
@@ -217,6 +232,31 @@ def discrete_uniform_sampler(upper_value):
     return int(np.random.random() * upper_value)
 
 
+class SampleSplittingVariable:
+    def __init__(self, prior, num_variates):
+        self.prior = prior
+        self.num_variates = num_variates
+
+        if self.prior is not None:
+            self.prior = np.asarray(self.prior)
+            self.prior = self.prior / self.prior.sum()
+            if self.prior.size != self.num_variates:
+                raise ValueError(
+                    f"The size of split_prior ({self.prior.size}) should be the "
+                    f"same as the number of covariates ({self.num_variates})"
+                )
+            self.enu = list(enumerate(np.cumsum(self.prior)))
+
+    def rvs(self):
+        if self.prior is None:
+            return int(np.random.random() * self.num_variates)
+        else:
+            r = np.random.random()
+            for i, v in self.enu:
+                if r <= v:
+                    return i
+
+
 class BART(BaseBART):
     """
     BART distribution.
@@ -225,19 +265,23 @@ class BART(BaseBART):
 
     Parameters
     ----------
-    X :
+    X : array-like
         The design matrix.
-    Y :
+    Y : array-like
         The response vector.
     m : int
         Number of trees
     alpha : float
         Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
         altought it is recomenned to be in the interval (0, 0.5].
+    split_prior : array-like
+        Each element of split_prior should be in the [0, 1] interval and the elements should sum
+        to 1. Otherwise they will be normalized.
+        Defaults to None, all variable have the same a prior probability
    """
 
-    def __init__(self, X, Y, m=200, alpha=0.25):
-        super().__init__(X, Y, m, alpha)
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
+        super().__init__(X, Y, m, alpha, split_prior)
 
     def _str_repr(self, name=None, dist=None, formatting="plain"):
         if dist is None:

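A note on SampleSplittingVariable.rvs above: with a split_prior it draws a covariate index by inverse-CDF sampling on the cumulative sum of the normalized weights, and without one it falls back to a uniform draw. A standalone sketch of the same idea (the function and variable names below are illustrative, not part of the commit):

    import numpy as np

    def sample_split_index(prior):
        # Normalize, take the cumulative sum, and return the first index whose
        # cumulative weight exceeds a uniform random draw (inverse-CDF sampling).
        prior = np.asarray(prior, dtype=float)
        cum = np.cumsum(prior / prior.sum())
        r = np.random.random()
        for i, v in enumerate(cum):
            if r <= v:
                return i

    draws = [sample_split_index([1, 2, 7]) for _ in range(10_000)]
    print(np.bincount(draws, minlength=3) / len(draws))  # roughly [0.1, 0.2, 0.7]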
pymc3/distributions/tree.py (+20, -6)

@@ -84,6 +84,22 @@ def predict_output(self, num_observations):
             output[current_node.idx_data_points] = current_node.value
         return output
 
+    def predict_out_of_sample(self, x):
+        """
+        Predict output of tree for an unobserved point x.
+
+        Parameters
+        ----------
+        x : numpy array
+
+        Returns
+        -------
+        float
+            Value of the leaf value where the unobserved point lies.
+        """
+        leaf_node = self._traverse_tree(x=x, node_index=0)
+        return leaf_node.value
+
     def _traverse_tree(self, x, node_index=0):
         """
         Traverse the tree starting from a particular node given an unobserved point.
@@ -99,15 +115,13 @@ def _traverse_tree(self, x, node_index=0):
         """
         current_node = self.get_node(node_index)
         if isinstance(current_node, SplitNode):
-            if x is not np.NaN:
+            if x[current_node.idx_split_variable] <= current_node.split_value:
                 left_child = current_node.get_idx_left_child()
-                final_node = self._traverse_tree(x, left_child)
+                current_node = self._traverse_tree(x, left_child)
             else:
                 right_child = current_node.get_idx_right_child()
-                final_node = self._traverse_tree(x, right_child)
-        else:
-            final_node = current_node
-        return final_node
+                current_node = self._traverse_tree(x, right_child)
+        return current_node
 
     def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
         """

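The rewritten _traverse_tree now routes an unseen point by comparing its value at the node's split variable against the node's split value, recursing left when x[idx_split_variable] <= split_value and right otherwise, until a leaf is reached and its value returned. A toy recursion with the same shape (the tuple-based nodes below are simplified stand-ins for SplitNode and LeafNode):

    # Internal node: (split_variable, split_value, left_child, right_child); leaf: a float
    tree = (0, 0.5,                 # split on x[0] <= 0.5
            (1, 2.0, -1.0, 1.0),    # left child: split on x[1] <= 2.0
            3.0)                    # right child: a leaf with value 3.0

    def traverse(node, x):
        if not isinstance(node, tuple):  # leaf reached: return its value
            return node
        var, value, left, right = node
        return traverse(left if x[var] <= value else right, x)

    print(traverse(tree, [0.2, 5.0]))  # 1.0 (left at the root, then right)
    print(traverse(tree, [0.9, 0.0]))  # 3.0 (right at the root)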
pymc3/step_methods/pgbart.py (+12, -2)

@@ -64,8 +64,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
 
         self.tune = True
         self.idx = 0
+        self.iter = 0
+        self.sum_trees = []
+        self.chunk = chunk
+
         if chunk == "auto":
             self.chunk = max(1, int(self.bart.m * 0.1))
+        self.bart.chunk = self.chunk
         self.num_particles = num_particles
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
@@ -96,14 +101,14 @@ def astep(self, _):
             self.idx = 0
 
         for idx in range(self.idx, self.idx + self.chunk):
-            if idx > bart.m:
+            if idx >= bart.m:
                 break
             self.idx += 1
             tree = bart.trees[idx]
             R_j = bart.get_residuals_loo(tree)
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id, R_j, bart.num_observations)
+            particles = self.init_particles(tree.tree_id, R_j, num_observations)
 
             for t in range(1, max_stages):
                 # Get old particle at stage t
@@ -147,6 +152,11 @@ def astep(self, _):
             bart.sum_trees_output = bart.Y - R_j + new_prediction
 
             if not self.tune:
+                self.iter += 1
+                self.sum_trees.append(new_tree.tree)
+                if not self.iter % bart.m:
+                    bart.all_trees.append(self.sum_trees)
+                    self.sum_trees = []
                 for index in new_tree.used_variates:
                     variable_inclusion[index] += 1
 

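The bookkeeping added to PGBART.astep collects trees only after tuning: each accepted tree is appended to self.sum_trees, and once the counter hits a multiple of bart.m the list holds one complete ensemble and is flushed into bart.all_trees, which BaseBART.predict later iterates over. A stripped-down sketch of that counter logic (the loop and the names are illustrative):

    m = 3            # trees per ensemble
    all_trees = []   # one entry per completed post-tuning draw
    sum_trees = []   # trees accumulated for the current draw
    for iteration in range(1, 10):
        sum_trees.append("tree_%d" % iteration)
        if not iteration % m:          # every m-th tree completes an ensemble
            all_trees.append(sum_trees)
            sum_trees = []

    print(len(all_trees))  # 3 complete ensembles after 9 iterations
    print(all_trees[0])    # ['tree_1', 'tree_2', 'tree_3']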
0 commit comments
