14
14
15
15
import numpy as np
16
16
17
+ from pandas import DataFrame , Series
18
+
17
19
from pymc3 .distributions .distribution import NoDistribution
18
20
from pymc3 .distributions .tree import LeafNode , SplitNode , Tree
19
21
20
22
__all__ = ["BART" ]
21
23
22
24
23
25
class BaseBART (NoDistribution ):
24
- def __init__ (self , X , Y , m = 200 , alpha = 0.25 , * args , ** kwargs ):
25
- self .X = X
26
- self .Y = Y
26
+ def __init__ (self , X , Y , m = 200 , alpha = 0.25 , split_prior = None , * args , ** kwargs ):
27
+
28
+ self .X , self .Y , self .missing_data = self .preprocess_XY (X , Y )
29
+
27
30
super ().__init__ (shape = X .shape [0 ], dtype = "float64" , testval = 0 , * args , ** kwargs )
28
31
29
32
if self .X .ndim != 2 :
@@ -48,12 +51,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
48
51
49
52
self .num_observations = X .shape [0 ]
50
53
self .num_variates = X .shape [1 ]
54
+ self .available_predictors = list (range (self .num_variates ))
55
+ self .ssv = SampleSplittingVariable (split_prior , self .num_variates )
51
56
self .m = m
52
57
self .alpha = alpha
53
58
self .trees = self .init_list_of_trees ()
59
+ self .all_trees = []
54
60
self .mean = fast_mean ()
55
61
self .prior_prob_leaf_node = compute_prior_probability (alpha )
56
62
63
def preprocess_XY(self, X, Y):
    """Convert X and Y to NumPy arrays, detect missing data and jitter X.

    Parameters
    ----------
    X : array-like
        Design matrix; pandas objects are converted with ``to_numpy``.
    Y : array-like
        Response vector; pandas objects are converted with ``to_numpy``.

    Returns
    -------
    X : np.ndarray
        Jittered copy of the design matrix (NaN entries stay NaN).
    Y : np.ndarray
        The response vector as a NumPy array.
    missing_data : bool
        Whether X contains any NaN entries.
    """
    if isinstance(Y, (Series, DataFrame)):
        Y = Y.to_numpy()
    if isinstance(X, (Series, DataFrame)):
        X = X.to_numpy()
    missing_data = np.any(np.isnan(X))
    # Add a tiny per-column jitter. NOTE(review): presumably this breaks ties
    # between repeated covariate values so more splitting rules are available
    # — confirm against the tree-growing code.
    # Use nanstd so columns containing NaNs keep a finite scale; np.std would
    # return NaN for such columns and the jitter would wipe them out entirely.
    X = np.random.normal(X, np.nanstd(X, 0) / 100)
    return X, Y, missing_data
71
+
57
72
def init_list_of_trees (self ):
58
73
initial_value_leaf_nodes = self .Y .mean () / self .m
59
74
initial_idx_data_points_leaf_nodes = np .array (range (self .num_observations ), dtype = "int32" )
@@ -79,39 +94,26 @@ def __iter__(self):
79
94
def __repr_latex (self ):
80
95
raise NotImplementedError
81
96
82
- def get_available_predictors (self , idx_data_points_split_node ):
83
- possible_splitting_variables = []
84
- for j in range (self .num_variates ):
85
- x_j = self .X [idx_data_points_split_node , j ]
86
- x_j = x_j [~ np .isnan (x_j )]
87
- for i in range (1 , len (x_j )):
88
- if x_j [i - 1 ] != x_j [i ]:
89
- possible_splitting_variables .append (j )
90
- break
91
- return possible_splitting_variables
92
-
93
97
def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
    """Return the candidate split values for one variable at a node.

    The candidates are the sorted unique values of the selected variable,
    restricted to the data points in the node; NaNs are dropped when the
    design matrix has missing data. The largest value is excluded because
    choosing it as the splitting rule would leave the right subtree empty.
    """
    candidates = self.X[idx_data_points_split_node, idx_split_variable]
    if self.missing_data:
        candidates = candidates[~np.isnan(candidates)]
    unique_values = np.unique(candidates)
    # np.unique returns sorted values, so dropping the last entry removes
    # the maximum — the value that can never be a valid splitting rule.
    return unique_values[:-1]
100
104
101
105
def grow_tree (self , tree , index_leaf_node ):
102
- # This can be unsuccessful when there are not available predictors
103
106
current_node = tree .get_node (index_leaf_node )
104
107
105
- available_predictors = self .get_available_predictors (current_node .idx_data_points )
106
-
107
- if not available_predictors :
108
- return False , None
109
-
110
- index_selected_predictor = discrete_uniform_sampler (len (available_predictors ))
111
- selected_predictor = available_predictors [index_selected_predictor ]
112
- available_splitting_rules , _ = self .get_available_splitting_rules (
108
+ index_selected_predictor = self .ssv .rvs ()
109
+ selected_predictor = self .available_predictors [index_selected_predictor ]
110
+ available_splitting_rules = self .get_available_splitting_rules (
113
111
current_node .idx_data_points , selected_predictor
114
112
)
113
+ # This can be unsuccessful when there are not available splitting rules
114
+ if available_splitting_rules .size == 0 :
115
+ return False , None
116
+
115
117
index_selected_splitting_rule = discrete_uniform_sampler (len (available_splitting_rules ))
116
118
selected_splitting_rule = available_splitting_rules [index_selected_splitting_rule ]
117
119
new_split_node = SplitNode (
@@ -167,6 +169,19 @@ def draw_leaf_value(self, idx_data_points):
167
169
draw = self .mean (R_j )
168
170
return draw
169
171
172
def predict(self, X_new):
    """Compute out-of-sample predictions evaluated at X_new.

    Parameters
    ----------
    X_new : array-like
        Matrix of new observations, one row per observation.

    Returns
    -------
    pred : np.ndarray
        Array of shape ``(len(self.all_trees), X_new.shape[0])`` where each
        row is the sum-of-trees prediction for one stored posterior draw.
    """
    trees = self.all_trees
    num_observations = X_new.shape[0]
    pred = np.zeros((len(trees), num_observations))
    # The original code called np.random.randint(len(trees)) here and
    # discarded the result — dead code that only perturbed the global RNG
    # stream; removed.
    for draw, trees_to_sum in enumerate(trees):
        new_Y = np.zeros(num_observations)
        for tree in trees_to_sum:
            # Each tree is evaluated point-by-point and the per-tree
            # contributions are summed to form this draw's prediction.
            new_Y += np.array([tree.predict_out_of_sample(x) for x in X_new])
        pred[draw] = new_Y
    return pred
184
+
170
185
171
186
def compute_prior_probability (alpha ):
172
187
"""
@@ -217,6 +232,31 @@ def discrete_uniform_sampler(upper_value):
217
232
return int (np .random .random () * upper_value )
218
233
219
234
235
class SampleSplittingVariable:
    """Sample the index of the splitting variable, optionally weighted by a prior.

    Parameters
    ----------
    prior : array-like or None
        Unnormalized per-variable weights; normalized to sum to 1. When None,
        every variable is sampled with equal probability.
    num_variates : int
        Number of covariates (columns of the design matrix).

    Raises
    ------
    ValueError
        If ``prior`` is given and its size differs from ``num_variates``.
    """

    def __init__(self, prior, num_variates):
        self.prior = prior
        self.num_variates = num_variates

        if self.prior is not None:
            self.prior = np.asarray(self.prior)
            self.prior = self.prior / self.prior.sum()
            if self.prior.size != self.num_variates:
                raise ValueError(
                    f"The size of split_prior ({self.prior.size}) should be the "
                    f"same as the number of covariates ({self.num_variates})"
                )
            # Precompute (index, cumulative probability) pairs for inverse-CDF sampling.
            self.enu = list(enumerate(np.cumsum(self.prior)))

    def rvs(self):
        """Return a variable index in ``[0, num_variates)`` sampled per the prior."""
        if self.prior is None:
            return int(np.random.random() * self.num_variates)
        r = np.random.random()
        for i, v in self.enu:
            if r <= v:
                return i
        # Floating-point round-off can leave the final cumulative value just
        # below r; without this fallback the original code returned None.
        return self.enu[-1][0]
259
+
220
260
class BART (BaseBART ):
221
261
"""
222
262
BART distribution.
@@ -225,19 +265,23 @@ class BART(BaseBART):
225
265
226
266
Parameters
227
267
----------
228
- X :
268
+ X : array-like
229
269
The design matrix.
230
- Y :
270
+ Y : array-like
231
271
The response vector.
232
272
m : int
233
273
Number of trees
234
274
alpha : float
235
275
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
236
276
although it is recommended to be in the interval (0, 0.5].
277
+ split_prior : array-like
278
+ Each element of split_prior should be in the [0, 1] interval and the elements should sum
279
+ to 1. Otherwise they will be normalized.
280
+ Defaults to None, all variables have the same prior probability.
237
281
"""
238
282
239
def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
    """Initialize a BART distribution by delegating to the BaseBART machinery."""
    super().__init__(X, Y, m=m, alpha=alpha, split_prior=split_prior)
241
285
242
286
def _str_repr (self , name = None , dist = None , formatting = "plain" ):
243
287
if dist is None :
0 commit comments