|
| 1 | +# Copyright 2020 The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +from .distribution import NoDistribution |
| 17 | +from .tree import Tree, SplitNode, LeafNode |
| 18 | + |
| 19 | +__all__ = ["BART"] |
| 20 | + |
| 21 | + |
| 22 | +class BaseBART(NoDistribution): |
| 23 | + def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs): |
| 24 | + self.X = X |
| 25 | + self.Y = Y |
| 26 | + super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs) |
| 27 | + |
| 28 | + if self.X.ndim != 2: |
| 29 | + raise ValueError("The design matrix X must have two dimensions") |
| 30 | + |
| 31 | + if self.Y.ndim != 1: |
| 32 | + raise ValueError("The response matrix Y must have one dimension") |
| 33 | + if self.X.shape[0] != self.Y.shape[0]: |
| 34 | + raise ValueError( |
| 35 | + "The design matrix X and the response matrix Y must have the same number of elements" |
| 36 | + ) |
| 37 | + if not isinstance(m, int): |
| 38 | + raise ValueError("The number of trees m type must be int") |
| 39 | + if m < 1: |
| 40 | + raise ValueError("The number of trees m must be greater than zero") |
| 41 | + |
| 42 | + if alpha <= 0 or 1 <= alpha: |
| 43 | + raise ValueError( |
| 44 | + "The value for the alpha parameter for the tree structure " |
| 45 | + "must be in the interval (0, 1)" |
| 46 | + ) |
| 47 | + |
| 48 | + self.num_observations = X.shape[0] |
| 49 | + self.num_variates = X.shape[1] |
| 50 | + self.m = m |
| 51 | + self.alpha = alpha |
| 52 | + self.trees = self.init_list_of_trees() |
| 53 | + self.mean = fast_mean() |
| 54 | + self.prior_prob_leaf_node = compute_prior_probability(alpha) |
| 55 | + |
| 56 | + def init_list_of_trees(self): |
| 57 | + initial_value_leaf_nodes = self.Y.mean() / self.m |
| 58 | + initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32") |
| 59 | + list_of_trees = [] |
| 60 | + for i in range(self.m): |
| 61 | + new_tree = Tree.init_tree( |
| 62 | + tree_id=i, |
| 63 | + leaf_node_value=initial_value_leaf_nodes, |
| 64 | + idx_data_points=initial_idx_data_points_leaf_nodes, |
| 65 | + ) |
| 66 | + list_of_trees.append(new_tree) |
| 67 | + # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J. |
| 68 | + # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013 |
| 69 | + # The sum_trees_output will contain the sum of the predicted output for all trees. |
| 70 | + # When R_j is needed we subtract the current predicted output for tree T_j. |
| 71 | + self.sum_trees_output = np.full_like(self.Y, self.Y.mean()) |
| 72 | + |
| 73 | + return list_of_trees |
| 74 | + |
| 75 | + def __iter__(self): |
| 76 | + return iter(self.trees) |
| 77 | + |
| 78 | + def __repr_latex(self): |
| 79 | + raise NotImplementedError |
| 80 | + |
| 81 | + def get_available_predictors(self, idx_data_points_split_node): |
| 82 | + possible_splitting_variables = [] |
| 83 | + for j in range(self.num_variates): |
| 84 | + x_j = self.X[idx_data_points_split_node, j] |
| 85 | + x_j = x_j[~np.isnan(x_j)] |
| 86 | + for i in range(1, len(x_j)): |
| 87 | + if x_j[i - 1] != x_j[i]: |
| 88 | + possible_splitting_variables.append(j) |
| 89 | + break |
| 90 | + return possible_splitting_variables |
| 91 | + |
| 92 | + def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable): |
| 93 | + x_j = self.X[idx_data_points_split_node, idx_split_variable] |
| 94 | + x_j = x_j[~np.isnan(x_j)] |
| 95 | + values, indices = np.unique(x_j, return_index=True) |
| 96 | + # The last value is not consider since if we choose it as the value of |
| 97 | + # the splitting rule assignment, it would leave the right subtree empty. |
| 98 | + return values[:-1], indices[:-1] |
| 99 | + |
| 100 | + def grow_tree(self, tree, index_leaf_node): |
| 101 | + # This can be unsuccessful when there are not available predictors |
| 102 | + current_node = tree.get_node(index_leaf_node) |
| 103 | + |
| 104 | + available_predictors = self.get_available_predictors(current_node.idx_data_points) |
| 105 | + |
| 106 | + if not available_predictors: |
| 107 | + return False, None |
| 108 | + |
| 109 | + index_selected_predictor = discrete_uniform_sampler(len(available_predictors)) |
| 110 | + selected_predictor = available_predictors[index_selected_predictor] |
| 111 | + available_splitting_rules, _ = self.get_available_splitting_rules( |
| 112 | + current_node.idx_data_points, selected_predictor |
| 113 | + ) |
| 114 | + index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules)) |
| 115 | + selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule] |
| 116 | + new_split_node = SplitNode( |
| 117 | + index=index_leaf_node, |
| 118 | + idx_split_variable=selected_predictor, |
| 119 | + split_value=selected_splitting_rule, |
| 120 | + ) |
| 121 | + |
| 122 | + left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points( |
| 123 | + new_split_node, current_node.idx_data_points |
| 124 | + ) |
| 125 | + |
| 126 | + left_node_value = self.draw_leaf_value(left_node_idx_data_points) |
| 127 | + right_node_value = self.draw_leaf_value(right_node_idx_data_points) |
| 128 | + |
| 129 | + new_left_node = LeafNode( |
| 130 | + index=current_node.get_idx_left_child(), |
| 131 | + value=left_node_value, |
| 132 | + idx_data_points=left_node_idx_data_points, |
| 133 | + ) |
| 134 | + new_right_node = LeafNode( |
| 135 | + index=current_node.get_idx_right_child(), |
| 136 | + value=right_node_value, |
| 137 | + idx_data_points=right_node_idx_data_points, |
| 138 | + ) |
| 139 | + tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node) |
| 140 | + |
| 141 | + return True, index_selected_predictor |
| 142 | + |
| 143 | + def get_new_idx_data_points(self, current_split_node, idx_data_points): |
| 144 | + idx_split_variable = current_split_node.idx_split_variable |
| 145 | + split_value = current_split_node.split_value |
| 146 | + |
| 147 | + left_idx = self.X[idx_data_points, idx_split_variable] <= split_value |
| 148 | + left_node_idx_data_points = idx_data_points[left_idx] |
| 149 | + right_node_idx_data_points = idx_data_points[~left_idx] |
| 150 | + |
| 151 | + return left_node_idx_data_points, right_node_idx_data_points |
| 152 | + |
| 153 | + def get_residuals(self): |
| 154 | + """Compute the residuals.""" |
| 155 | + R_j = self.Y - self.sum_trees_output |
| 156 | + return R_j |
| 157 | + |
| 158 | + def get_residuals_loo(self, tree): |
| 159 | + """Compute the residuals without leaving the passed tree out.""" |
| 160 | + R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations)) |
| 161 | + return R_j |
| 162 | + |
| 163 | + def draw_leaf_value(self, idx_data_points): |
| 164 | + """ Draw the residual mean.""" |
| 165 | + R_j = self.get_residuals()[idx_data_points] |
| 166 | + draw = self.mean(R_j) |
| 167 | + return draw |
| 168 | + |
| 169 | + |
| 170 | +def compute_prior_probability(alpha): |
| 171 | + """ |
| 172 | + Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)). |
| 173 | + Taken from equation 19 in [Rockova2018]. |
| 174 | +
|
| 175 | + Parameters |
| 176 | + ---------- |
| 177 | + alpha : float |
| 178 | +
|
| 179 | + Returns |
| 180 | + ------- |
| 181 | + list with probabilities for leaf nodes |
| 182 | +
|
| 183 | + References |
| 184 | + ---------- |
| 185 | + .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. |
| 186 | + arXiv, `link <https://arxiv.org/abs/1810.00787>`__ |
| 187 | + """ |
| 188 | + prior_leaf_prob = [0] |
| 189 | + depth = 1 |
| 190 | + while prior_leaf_prob[-1] < 1: |
| 191 | + prior_leaf_prob.append(1 - alpha ** depth) |
| 192 | + depth += 1 |
| 193 | + return prior_leaf_prob |
| 194 | + |
| 195 | + |
| 196 | +def fast_mean(): |
| 197 | + """If available use Numba to speed up the computation of the mean.""" |
| 198 | + try: |
| 199 | + from numba import jit |
| 200 | + except ImportError: |
| 201 | + return np.mean |
| 202 | + |
| 203 | + @jit |
| 204 | + def mean(a): |
| 205 | + count = a.shape[0] |
| 206 | + suma = 0 |
| 207 | + for i in range(count): |
| 208 | + suma += a[i] |
| 209 | + return suma / count |
| 210 | + |
| 211 | + return mean |
| 212 | + |
| 213 | + |
| 214 | +def discrete_uniform_sampler(upper_value): |
| 215 | + """Draw from the uniform distribution with bounds [0, upper_value).""" |
| 216 | + return int(np.random.random() * upper_value) |
| 217 | + |
| 218 | + |
| 219 | +class BART(BaseBART): |
| 220 | + """ |
| 221 | + BART distribution. |
| 222 | +
|
| 223 | + Distribution representing a sum over trees |
| 224 | +
|
| 225 | + Parameters |
| 226 | + ---------- |
| 227 | + X : |
| 228 | + The design matrix. |
| 229 | + Y : |
| 230 | + The response vector. |
| 231 | + m : int |
| 232 | + Number of trees |
| 233 | + alpha : float |
| 234 | + Control the prior probability over the depth of the trees. Must be in the interval (0, 1), |
| 235 | + altought it is recomenned to be in the interval (0, 0.5]. |
| 236 | + """ |
| 237 | + |
| 238 | + def __init__(self, X, Y, m=200, alpha=0.25): |
| 239 | + super().__init__(X, Y, m, alpha) |
| 240 | + |
| 241 | + def _str_repr(self, name=None, dist=None, formatting="plain"): |
| 242 | + if dist is None: |
| 243 | + dist = self |
| 244 | + X = (type(self.X),) |
| 245 | + Y = (type(self.Y),) |
| 246 | + alpha = self.alpha |
| 247 | + m = self.m |
| 248 | + |
| 249 | + if formatting == "latex": |
| 250 | + return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$" |
| 251 | + else: |
| 252 | + return f"{name} ~ BART(alpha = {alpha}, m = {m})" |
0 commit comments