From f109c12ce3312641772a848d3e80cc628c752f62 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 3 Dec 2020 19:34:34 -0300 Subject: [PATCH 1/9] refactor split variables, add prior split variables, add predict function --- pymc3/distributions/bart.py | 115 +++++++++++++++++++++++++---------- pymc3/distributions/tree.py | 26 ++++++-- pymc3/step_methods/pgbart.py | 14 ++++- 3 files changed, 115 insertions(+), 40 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index c512d84576..25ad9b007c 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -13,6 +13,8 @@ # limitations under the License. import numpy as np +from pandas import DataFrame, Series +from scipy.stats import multinomial from .distribution import NoDistribution from .tree import Tree, SplitNode, LeafNode @@ -20,9 +22,10 @@ class BaseBART(NoDistribution): - def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs): - self.X = X - self.Y = Y + def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs): + + self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y) + super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs) if self.X.ndim != 2: @@ -47,12 +50,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs): self.num_observations = X.shape[0] self.num_variates = X.shape[1] + self.available_predictors = list(range(self.num_variates)) + self.ssv = sample_splitting_variable(split_prior, self.num_variates) self.m = m self.alpha = alpha self.trees = self.init_list_of_trees() + self.all_trees = [] self.mean = fast_mean() self.prior_prob_leaf_node = compute_prior_probability(alpha) + def preprocess_XY(self, X, Y): + if isinstance(Y, (Series, DataFrame)): + Y = Y.values + if isinstance(X, (Series, DataFrame)): + X = X.values + missing_data = np.any(np.isnan(X)) + X = np.random.normal(X, np.std(X, 0) / 100) + return X, Y, missing_data + def init_list_of_trees(self): initial_value_leaf_nodes = self.Y.mean() / self.m initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32") @@ -78,39 +93,26 @@ def __iter__(self): def __repr_latex(self): raise NotImplementedError - def get_available_predictors(self, idx_data_points_split_node): - possible_splitting_variables = [] - for j in range(self.num_variates): - x_j = self.X[idx_data_points_split_node, j] - x_j = x_j[~np.isnan(x_j)] - for i in range(1, len(x_j)): - if x_j[i - 1] != x_j[i]: - possible_splitting_variables.append(j) - break - return possible_splitting_variables - def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable): x_j = self.X[idx_data_points_split_node, idx_split_variable] - x_j = x_j[~np.isnan(x_j)] - values, indices = np.unique(x_j, return_index=True) - # The last value is not consider since if we choose it as the value of - # the splitting rule assignment, it would leave the right subtree empty. - return values[:-1], indices[:-1] + if self.missing_data: + x_j = x_j[~np.isnan(x_j)] + values = np.unique(x_j) + # The last value is never available as it would leave the right subtree empty. + return values[:-1] def grow_tree(self, tree, index_leaf_node): - # This can be unsuccessful when there are not available predictors current_node = tree.get_node(index_leaf_node) - available_predictors = self.get_available_predictors(current_node.idx_data_points) - - if not available_predictors: - return False, None - - index_selected_predictor = discrete_uniform_sampler(len(available_predictors)) - selected_predictor = available_predictors[index_selected_predictor] - available_splitting_rules, _ = self.get_available_splitting_rules( + index_selected_predictor = self.ssv.rvs() + selected_predictor = self.available_predictors[index_selected_predictor] + available_splitting_rules = self.get_available_splitting_rules( current_node.idx_data_points, selected_predictor ) + # This can be unsuccessful when there are not available splitting rules + if available_splitting_rules.size == 0: + return False, None + index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules)) selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule] new_split_node = SplitNode( @@ -166,6 +168,18 @@ def draw_leaf_value(self, idx_data_points): draw = self.mean(R_j) return draw + def predict(self, X_new): + """Compute out of sample predictions evaluated at X_new""" + trees = self.all_trees + num_observations = X_new.shape[0] + pred = np.zeros((len(trees), num_observations)) + for draw, trees_to_sum in enumerate(trees): + new_Y = np.zeros(X_new.shape[0]) + for tree in trees_to_sum: + new_Y += [tree.predict_out_of_sample(x) for x in X_new] + pred[draw] = new_Y + return pred + def compute_prior_probability(alpha): """ @@ -216,6 +230,39 @@ def discrete_uniform_sampler(upper_value): return int(np.random.random() * upper_value) +class sample_splitting_variable: + def __init__(self, prior, num_variates): + self.prior = prior + self.num_variates = num_variates + + if self.prior is not None: + self.prior = np.asarray(self.prior) + self.prior = self.prior / self.prior.sum() + if self.prior.size != self.num_variates: + raise ValueError( + f"The size of split_prior ({self.prior.size}) should be the " + f"same as the number of covariates ({self.num_variates})" + ) + self.enu = list(enumerate(np.cumsum(self.prior))) + + def rvs(self): + if self.prior is None: + return int(np.random.random() * self.num_variates) + else: + r = np.random.random() + for i, v in self.enu: + if r <= v: + return i + + +# def dirichlet(p): +# """Draw from a dirichlet distribution.""" + +# a = np.random.beta(0.5, 1) +# concentration = [(-(p * a) / (-1 + a)) / p] * p +# return np.random.dirichlet(alpha=concentration).argmax() + + class BART(BaseBART): """ BART distribution. @@ -224,19 +271,23 @@ class BART(BaseBART): Parameters ---------- - X : + X : array-like The design matrix. - Y : + Y : array-like The response vector. m : int Number of trees alpha : float Control the prior probability over the depth of the trees. Must be in the interval (0, 1), altought it is recomenned to be in the interval (0, 0.5]. + split_prior : array-like + Each element of split_prior should be in the [0, 1] interval and the elements should sum + to 1. Otherwise they will be normalized. + Defaults to None, all variable have the same a prior probability """ - def __init__(self, X, Y, m=200, alpha=0.25): - super().__init__(X, Y, m, alpha) + def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None): + super().__init__(X, Y, m, alpha, split_prior) def _str_repr(self, name=None, dist=None, formatting="plain"): if dist is None: diff --git a/pymc3/distributions/tree.py b/pymc3/distributions/tree.py index 3b2c098e89..06d0bf27f3 100644 --- a/pymc3/distributions/tree.py +++ b/pymc3/distributions/tree.py @@ -82,6 +82,22 @@ def predict_output(self, num_observations): output[current_node.idx_data_points] = current_node.value return output + def predict_out_of_sample(self, x): + """ + Predict output of tree for an unobserved point x. + + Parameters + ---------- + x : numpy array + + Returns + ------- + float + Value of the leaf value where the unobserved point lies. + """ + leaf_node = self._traverse_tree(x=x, node_index=0) + return leaf_node.value + def _traverse_tree(self, x, node_index=0): """ Traverse the tree starting from a particular node given an unobserved point. @@ -97,15 +113,13 @@ def _traverse_tree(self, x, node_index=0): """ current_node = self.get_node(node_index) if isinstance(current_node, SplitNode): - if x is not np.NaN: + if x[current_node.idx_split_variable] <= current_node.split_value: left_child = current_node.get_idx_left_child() - final_node = self._traverse_tree(x, left_child) + current_node = self._traverse_tree(x, left_child) else: right_child = current_node.get_idx_right_child() - final_node = self._traverse_tree(x, right_child) - else: - final_node = current_node - return final_node + current_node = self._traverse_tree(x, right_child) + return current_node def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node): """ diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py index 1d6a503c4c..245e37e9e2 100644 --- a/pymc3/step_methods/pgbart.py +++ b/pymc3/step_methods/pgbart.py @@ -63,8 +63,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m self.tune = True self.idx = 0 + self.iter = 0 + self.sum_trees = [] + self.chunk = chunk + if chunk == "auto": self.chunk = max(1, int(self.bart.m * 0.1)) + self.bart.chunk = self.chunk self.num_particles = num_particles self.log_num_particles = np.log(num_particles) self.indices = list(range(1, num_particles)) @@ -95,14 +100,14 @@ def astep(self, _): self.idx = 0 for idx in range(self.idx, self.idx + self.chunk): - if idx > bart.m: + if idx >= bart.m: break self.idx += 1 tree = bart.trees[idx] R_j = bart.get_residuals_loo(tree) # Generate an initial set of SMC particles # at the end of the algorithm we return one of these particles as the new tree - particles = self.init_particles(tree.tree_id, R_j, bart.num_observations) + particles = self.init_particles(tree.tree_id, R_j, num_observations) for t in range(1, max_stages): # Get old particle at stage t @@ -146,6 +151,11 @@ def astep(self, _): bart.sum_trees_output = bart.Y - R_j + new_prediction if not self.tune: + self.iter += 1 + self.sum_trees.append(new_tree.tree) + if not self.iter % bart.m: + bart.all_trees.append(self.sum_trees) + self.sum_trees = [] for index in new_tree.used_variates: variable_inclusion[index] += 1 From db8dac899dccbd9aad54f819454cc27866aa63ac Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 7 Dec 2020 12:16:03 -0300 Subject: [PATCH 2/9] remove dirichlet --- pymc3/distributions/bart.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index 25ad9b007c..cb69d84fef 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -255,14 +255,6 @@ def rvs(self): return i -# def dirichlet(p): -# """Draw from a dirichlet distribution.""" - -# a = np.random.beta(0.5, 1) -# concentration = [(-(p * a) / (-1 + a)) / p] * p -# return np.random.dirichlet(alpha=concentration).argmax() - - class BART(BaseBART): """ BART distribution. From 58f27340e28151317e311061f617bc8abd29c669 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Dec 2020 15:21:31 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc3/distributions/bart.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index 14c58261e5..c8208b7f51 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -13,8 +13,10 @@ # limitations under the License. import numpy as np + from pandas import DataFrame, Series from scipy.stats import multinomial + from .distribution import NoDistribution from .tree import LeafNode, SplitNode, Tree From 15d46ae86f036bccef902687dde155c4511858e2 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 7 Dec 2020 12:24:16 -0300 Subject: [PATCH 4/9] remove unused import --- pymc3/distributions/bart.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index c8208b7f51..f7db9233c4 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -15,8 +15,6 @@ import numpy as np from pandas import DataFrame, Series -from scipy.stats import multinomial - from .distribution import NoDistribution from .tree import LeafNode, SplitNode, Tree From a9ca8950edfedd372f423a02bd3f55d47122a90f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Dec 2020 15:26:15 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc3/distributions/bart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index f7db9233c4..2c27482c1d 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -15,6 +15,7 @@ import numpy as np from pandas import DataFrame, Series + from .distribution import NoDistribution from .tree import LeafNode, SplitNode, Tree From 2b539f4f8aa162621142b12152e47f98e4654cd5 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 7 Dec 2020 18:15:18 -0300 Subject: [PATCH 6/9] use to_numpy to convert a series/dataframe to an array. --- pymc3/distributions/bart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index f7db9233c4..6bb554d197 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -61,9 +61,9 @@ def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs): def preprocess_XY(self, X, Y): if isinstance(Y, (Series, DataFrame)): - Y = Y.values + Y = Y.to_numpy() if isinstance(X, (Series, DataFrame)): - X = X.values + X = X.to_numpy() missing_data = np.any(np.isnan(X)) X = np.random.normal(X, np.std(X, 0) / 100) return X, Y, missing_data From b3443847f4fecc029253d01dbc4f61404b465400 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 13 Jan 2021 08:32:47 -0300 Subject: [PATCH 7/9] use camelcase for class name --- pymc3/distributions/bart.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index e452aa8641..237ed9fba2 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -52,7 +52,7 @@ def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs): self.num_observations = X.shape[0] self.num_variates = X.shape[1] self.available_predictors = list(range(self.num_variates)) - self.ssv = sample_splitting_variable(split_prior, self.num_variates) + self.ssv = SampleSplittingVariable(split_prior, self.num_variates) self.m = m self.alpha = alpha self.trees = self.init_list_of_trees() @@ -174,6 +174,7 @@ def predict(self, X_new): trees = self.all_trees num_observations = X_new.shape[0] pred = np.zeros((len(trees), num_observations)) + np.random.randint(len(trees)) for draw, trees_to_sum in enumerate(trees): new_Y = np.zeros(X_new.shape[0]) for tree in trees_to_sum: @@ -231,7 +232,7 @@ def discrete_uniform_sampler(upper_value): return int(np.random.random() * upper_value) -class sample_splitting_variable: +class SampleSplittingVariable: def __init__(self, prior, num_variates): self.prior = prior self.num_variates = num_variates From 84b8f4028ea389473937839ca4dabb4bc798b09e Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 13 Jan 2021 08:52:32 -0300 Subject: [PATCH 8/9] fix imports --- pymc3/distributions/bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index 3fce4755d7..8d3f36f2be 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -19,7 +19,6 @@ from pymc3.distributions.distribution import NoDistribution from pymc3.distributions.tree import LeafNode, SplitNode, Tree - __all__ = ["BART"] From 8e94ed7d28e6700d1783a9fc505ea2c7f1c926ac Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Tue, 19 Jan 2021 09:24:47 -0300 Subject: [PATCH 9/9] use already defined variable --- pymc3/distributions/bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index 8d3f36f2be..4914844555 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -176,7 +176,7 @@ def predict(self, X_new): pred = np.zeros((len(trees), num_observations)) np.random.randint(len(trees)) for draw, trees_to_sum in enumerate(trees): - new_Y = np.zeros(X_new.shape[0]) + new_Y = np.zeros(num_observations) for tree in trees_to_sum: new_Y += [tree.predict_out_of_sample(x) for x in X_new] pred[draw] = new_Y