diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index f3035fe825..30ddd0490c 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -95,7 +95,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - New features for BART:
   - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
   - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
-  - Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
+  - Modify the PGBART sampler. Particles are no longer reweighted and the trees are reset from time to time to avoid getting trapped in a local minimum. This improves the accuracy of the modeled function and improves convergence (see [5223](https://github.com/pymc-devs/pymc3/pull/5223)).
 - `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
 - ...
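Note (annotation, not part of the patch): the heart of the new scheme is that each particle carries a single log-weight, namely the model log-likelihood of the sum of trees with that particle's tree swapped in, and the replacement tree is drawn from the softmax of those log-weights; the per-stage reweighting and resampling of the old sampler is gone. A minimal, self-contained sketch of that selection step, where the made-up log-likelihood values stand in for calls to `self.likelihood_logp`:

```python
import numpy as np

def normalize_weights(log_weights):
    """Softmax over log-weights, as defined in pgbart.py below."""
    w = np.exp(log_weights - log_weights.max())
    normalized = w / w.sum()
    # Stabilize to avoid assigning exactly zero probability to a particle.
    return normalized + 1e-12

# Hypothetical log-likelihoods for 5 particles; index 0 plays the role of
# the previous (unmodified) tree.
log_weights = np.array([-105.3, -101.2, -103.7, -100.9, -104.4])
probs = normalize_weights(log_weights)
new_tree_index = np.random.choice(len(probs), p=probs)
```

This is also why the old `normalize` method (which additionally computed the running weight `W_t`) could be dropped: with no resampling inside the stage loop, only the final softmax matters.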
diff --git a/pymc/bart/pgbart.py b/pymc/bart/pgbart.py
index 3bc7ac0a25..8b98515724 100644
--- a/pymc/bart/pgbart.py
+++ b/pymc/bart/pgbart.py
@@ -14,9 +14,6 @@

 import logging

-from copy import copy
-from typing import Any, Dict, List, Tuple
-
 import aesara
 import numpy as np

@@ -25,70 +22,12 @@
 from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
 from pymc.bart.bart import BARTRV
 from pymc.bart.tree import LeafNode, SplitNode, Tree
-from pymc.blocking import RaveledVars
 from pymc.model import modelcontext
 from pymc.step_methods.arraystep import ArrayStepShared, Competence

 _log = logging.getLogger("pymc")


-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree, log_weight, likelihood):
-        self.tree = tree.copy()  # keeps the tree that we care at the moment
-        self.expansion_nodes = [0]
-        self.log_weight = log_weight
-        self.old_likelihood_logp = likelihood
-        self.used_variates = []
-
-    def sample_tree_sequential(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees_output,
-        mean,
-        linear_fit,
-        m,
-        normal,
-        mu_std,
-        response,
-    ):
-        tree_grew = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                tree_grew, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees_output,
-                    mean,
-                    linear_fit,
-                    m,
-                    normal,
-                    mu_std,
-                    response,
-                )
-                if tree_grew:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-
-        return tree_grew
-
-
 class PGBART(ArrayStepShared):
     """
     Particle Gibss BART sampling step
@@ -103,21 +42,10 @@ class PGBART(ArrayStepShared):
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
-        during tuning and 20% after tuning. If a tuple is passed the first element is the batch size
+        during tuning and 100% after tuning. If a tuple is passed the first element is the batch size
         during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
-
-    Note
-    ----
-    This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces
-    several changes. The changes will be properly documented soon.
-
-    References
-    ----------
-    .. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
-        Particle Gibbs for Bayesian Additive Regression Trees.
-        ArviX, `link <https://arxiv.org/abs/1502.04622>`__
     """

     name = "bartsampler"
@@ -143,7 +71,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         if self.alpha_vec is None:
             self.alpha_vec = np.ones(self.X.shape[1])

-        self.init_mean = self.Y.mean()
+        self.init_mean = self.Y.mean() / self.m
         # if data is binary
         Y_unique = np.unique(self.Y)
         if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
@@ -156,10 +84,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.num_variates = self.X.shape[1]
         self.available_predictors = list(range(self.num_variates))

-        sum_trees_output = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
+        self.sum_trees = np.full_like(self.Y, self.init_mean * self.m).astype(aesara.config.floatX)
         self.a_tree = Tree.init_tree(
-            tree_id=0,
-            leaf_node_value=self.init_mean / self.m,
+            leaf_node_value=self.init_mean,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             m=self.m,
         )
@@ -169,11 +96,11 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
         self.ssv = SampleSplittingVariable(self.alpha_vec)
-
+        self.iter = 0
         self.tune = True

         if batch == "auto":
-            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+            self.batch = (max(1, int(self.m * 0.1)), self.m)
         else:
             if isinstance(batch, (tuple, list)):
                 self.batch = batch
@@ -182,53 +109,54 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
-        self.len_indices = len(self.indices)
+        self.log_weights = np.empty(num_particles)
         self.max_stages = max_stages

         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
-        self.init_likelihood = self.likelihood_logp(sum_trees_output)
-        self.init_log_weight = self.init_likelihood - self.log_num_particles
         self.all_particles = []
-        for i in range(self.m):
-            self.a_tree.tree_id = i
-            self.a_tree.leaf_node_value = (
-                self.init_mean / self.m + self.normal.random() * self.mu_std,
-            )
-            p = ParticleTree(
-                self.a_tree,
-                self.init_log_weight,
-                self.init_likelihood,
-            )
+        for _ in range(self.m):
+            p = ParticleTree(self.a_tree)
             self.all_particles.append(p)
         self.all_trees = np.array([p.tree for p in self.all_particles])
         super().__init__(vars, shared)

-    def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
-        point_map_info = q.point_map_info
-        sum_trees_output = q.data
+    def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
+        # Heuristic: reset the sampler from time to time to avoid getting stuck.
+        # Inspired by https://arxiv.org/abs/0806.3286
+        if self.iter >= self.m * 3:
+            for tree_id in tree_ids:
+                self.sum_trees = self.sum_trees - self.all_particles[tree_id].tree.predict_output()
+                p = ParticleTree(self.a_tree)
+                self.all_particles[tree_id] = p
+                self.all_trees[tree_id] = p.tree
+            self.sum_trees = self.sum_trees + self.init_mean * self.batch[~self.tune]
+            self.iter = 0
+
         for tree_id in tree_ids:
-            # Generate an initial set of SMC particles
+            self.iter += 1
+            # Generate an initial set of particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
+            # update the weight of the previous tree
+            self.log_weights[0] = self.likelihood_logp(self.sum_trees)
             # Compute the sum of trees without the tree we are attempting to replace
-            self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
+            self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()

-            # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0], new=True)
-            for t in range(self.max_stages):
+            for _ in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
-                for p in particles[1:]:
-                    tree_grew = p.sample_tree_sequential(
+                stop_growing = 1
+                for idx, p in enumerate(particles[1:]):
+                    p.sample_tree(
                         self.ssv,
                         self.available_predictors,
                         self.prior_prob_leaf_node,
                         self.X,
                         self.missing_data,
-                        sum_trees_output,
+                        self.sum_trees,
                         self.mean,
                         self.linear_fit,
                         self.m,
@@ -236,39 +164,22 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.mu_std,
                         self.response,
                     )
-                    if tree_grew:
-                        self.update_weight(p)
-                # Normalize weights
-                W_t, normalized_weights = self.normalize(particles[1:])
-
-                # Resample all but first particle
-                re_n_w = normalized_weights
-                new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
-                particles[1:] = particles[new_indices]
-
-                # Set the new weights
-                for p in particles[1:]:
-                    p.log_weight = W_t
-
-                # Check if particles can keep growing, otherwise stop iterating
-                non_available_nodes_for_expansion = []
-                for p in particles[1:]:
                     if p.expansion_nodes:
-                        non_available_nodes_for_expansion.append(0)
-                if all(non_available_nodes_for_expansion):
+                        stop_growing = 0
+                    else:
+                        self.log_weights[idx + 1] = self.likelihood_logp(
+                            self.sum_trees_noi + p.tree.predict_output()
+                        )
+                if stop_growing:
                     break

-            for p in particles[1:]:
-                p.log_weight = p.old_likelihood_logp
-
-            _, normalized_weights = self.normalize(particles)
+            normalized_weights = normalize_weights(self.log_weights)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
             self.all_trees[tree_id] = new_tree
-            new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
-            sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
+            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()

             if self.tune:
                 self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -278,9 +189,9 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         for index in new_particle.used_variates:
             variable_inclusion[index] += 1

-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
-        sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
-        return sum_trees_output, [stats]
+        # TODO: saving the trees could be made optional
+        stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
+        return self.sum_trees, [stats]

     @staticmethod
     def competence(var, has_grad):
@@ -292,58 +203,82 @@ def competence(var, has_grad):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE

-    def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
-        """
-        Use logsumexp trick to get W_t and softmax to get normalized_weights
-        """
-        log_w = np.array([p.log_weight for p in particles])
-        log_w_max = log_w.max()
-        log_w_ = log_w - log_w_max
-        w_ = np.exp(log_w_)
-        w_sum = w_.sum()
-        W_t = log_w_max + np.log(w_sum) - self.log_num_particles
-        normalized_weights = w_ / w_sum
-        # stabilize weights to avoid assigning exactly zero probability to a particle
-        normalized_weights += 1e-12
-
-        return W_t, normalized_weights
-
     def init_particles(self, tree_id: int) -> np.ndarray:
         """
         Initialize particles
         """
-        p = self.all_particles[tree_id]
-        p.log_weight = self.init_log_weight
-        p.old_likelihood_logp = self.init_likelihood
-        particles = [p]
+        particles = [self.all_particles[tree_id]]

         for _ in self.indices:
             self.a_tree.tree_id = tree_id
-            particles.append(
-                ParticleTree(
-                    self.a_tree,
-                    self.init_log_weight,
-                    self.init_likelihood,
-                )
-            )
-
+            particles.append(ParticleTree(self.a_tree))
         return np.array(particles)

-    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
-        """
-        Update the weight of a particle
-
-        Since the prior is used as the proposal,the weights are updated additively as the ratio of
-        the new and old log-likelihoods.
-        """
-        new_likelihood = self.likelihood_logp(
-            self.sum_trees_output_noi + particle.tree.predict_output()
-        )
-        if new:
-            particle.log_weight = new_likelihood
-        else:
-            particle.log_weight += new_likelihood - particle.old_likelihood_logp
-            particle.old_likelihood_logp = new_likelihood
+
+def normalize_weights(log_weights):
+    """
+    Use softmax to get normalized weights
+    """
+    log_w_max = log_weights.max()
+    log_w_ = log_weights - log_w_max
+    w_ = np.exp(log_w_)
+    normalized_weights = w_ / w_.sum()
+    # stabilize weights to avoid assigning exactly zero probability to a particle
+    normalized_weights += 1e-12
+
+    return normalized_weights
+
+
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree):
+        self.tree = tree.copy()  # keeps the tree we are currently working on
+        self.expansion_nodes = [0]
+        self.used_variates = []
+
+    def sample_tree(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees,
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
+    ):
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                tree_grew, index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees,
+                    mean,
+                    linear_fit,
+                    m,
+                    normal,
+                    mu_std,
+                    response,
+                )
+                if tree_grew:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)


 class SampleSplittingVariable:
variable_inclusion, "bart_trees": self.all_trees} + return self.sum_trees, [stats] @staticmethod def competence(var, has_grad): @@ -292,58 +203,82 @@ def competence(var, has_grad): return Competence.IDEAL return Competence.INCOMPATIBLE - def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]: - """ - Use logsumexp trick to get W_t and softmax to get normalized_weights - """ - log_w = np.array([p.log_weight for p in particles]) - log_w_max = log_w.max() - log_w_ = log_w - log_w_max - w_ = np.exp(log_w_) - w_sum = w_.sum() - W_t = log_w_max + np.log(w_sum) - self.log_num_particles - normalized_weights = w_ / w_sum - # stabilize weights to avoid assigning exactly zero probability to a particle - normalized_weights += 1e-12 - - return W_t, normalized_weights - def init_particles(self, tree_id: int) -> np.ndarray: """ Initialize particles """ - p = self.all_particles[tree_id] - p.log_weight = self.init_log_weight - p.old_likelihood_logp = self.init_likelihood - particles = [p] + particles = [self.all_particles[tree_id]] for _ in self.indices: self.a_tree.tree_id = tree_id - particles.append( - ParticleTree( - self.a_tree, - self.init_log_weight, - self.init_likelihood, - ) - ) - + particles.append(ParticleTree(self.a_tree)) return np.array(particles) - def update_weight(self, particle: List[ParticleTree], new=False) -> None: - """ - Update the weight of a particle - Since the prior is used as the proposal,the weights are updated additively as the ratio of - the new and old log-likelihoods. - """ - new_likelihood = self.likelihood_logp( - self.sum_trees_output_noi + particle.tree.predict_output() - ) - if new: - particle.log_weight = new_likelihood - else: - particle.log_weight += new_likelihood - particle.old_likelihood_logp - particle.old_likelihood_logp = new_likelihood +def normalize_weights(log_weights): + """ + Use softmax to get normalized_weights + """ + log_w_max = log_weights.max() + log_w_ = log_weights - log_w_max + w_ = np.exp(log_w_) + normalized_weights = w_ / w_.sum() + # stabilize weights to avoid assigning exactly zero probability to a particle + normalized_weights += 1e-12 + + return normalized_weights + + +class ParticleTree: + """ + Particle tree + """ + + def __init__(self, tree): + self.tree = tree.copy() # keeps the tree that we care at the moment + self.expansion_nodes = [0] + self.used_variates = [] + + def sample_tree( + self, + ssv, + available_predictors, + prior_prob_leaf_node, + X, + missing_data, + sum_trees, + mean, + linear_fit, + m, + normal, + mu_std, + response, + ): + if self.expansion_nodes: + index_leaf_node = self.expansion_nodes.pop(0) + # Probability that this node will remain a leaf node + prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth] + + if prob_leaf < np.random.random(): + tree_grew, index_selected_predictor = grow_tree( + self.tree, + index_leaf_node, + ssv, + available_predictors, + X, + missing_data, + sum_trees, + mean, + linear_fit, + m, + normal, + mu_std, + response, + ) + if tree_grew: + new_indexes = self.tree.idx_leaf_nodes[-2:] + self.expansion_nodes.extend(new_indexes) + self.used_variates.append(index_selected_predictor) class SampleSplittingVariable: @@ -396,7 +331,7 @@ def grow_tree( available_predictors, X, missing_data, - sum_trees_output, + sum_trees, mean, linear_fit, m, @@ -421,53 +356,48 @@ def grow_tree( idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values)) split_value = available_splitting_values[idx_selected_splitting_values] - new_split_node = 
diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py
index b982e80bb6..828c4fb6dc 100644
--- a/pymc/bart/tree.py
+++ b/pymc/bart/tree.py
@@ -38,11 +38,6 @@ class Tree:
         Total number of nodes.
     idx_leaf_nodes : list
         List with the index of the leaf nodes of the tree.
-    idx_prunable_split_nodes : list
-        List with the index of the prunable splitting nodes of the tree. A splitting node is
-        prunable if both its children are leaf nodes.
-    tree_id : int
-        Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
     num_observations : int
         Number of observations used to fit BART.
     m : int
@@ -50,16 +45,13 @@ class Tree:
         Number of trees

     Parameters
     ----------
-    tree_id : int, optional
     num_observations : int, optional
     """

-    def __init__(self, tree_id=0, num_observations=0, m=0):
+    def __init__(self, num_observations=0, m=0):
         self.tree_structure = {}
         self.num_nodes = 0
         self.idx_leaf_nodes = []
-        self.idx_prunable_split_nodes = []
-        self.tree_id = tree_id
         self.num_observations = num_observations
         self.m = m

@@ -143,39 +135,12 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
             current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
         return current_node, split_variable

-    def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
-        """
-        Grow the tree from a particular node.
-
-        Parameters
-        ----------
-        index_leaf_node : int
-        new_split_node : SplitNode
-        new_left_node : LeafNode
-        new_right_node : LeafNode
-        """
-        current_node = self.get_node(index_leaf_node)
-
-        self.delete_node(index_leaf_node)
-        self.set_node(index_leaf_node, new_split_node)
-        self.set_node(new_left_node.index, new_left_node)
-        self.set_node(new_right_node.index, new_right_node)
-
-        # The new SplitNode is a prunable node since it has both children.
-        self.idx_prunable_split_nodes.append(index_leaf_node)
-        # If the parent of the node from which the tree is growing was a prunable node,
-        # remove from the list since one of its children is a SplitNode now
-        parent_index = current_node.get_idx_parent_node()
-        if parent_index in self.idx_prunable_split_nodes:
-            self.idx_prunable_split_nodes.remove(parent_index)
-
     @staticmethod
-    def init_tree(tree_id, leaf_node_value, idx_data_points, m):
+    def init_tree(leaf_node_value, idx_data_points, m):
         """
         Parameters
         ----------
-        tree_id
         leaf_node_value
         idx_data_points
         m : int
@@ -185,7 +150,7 @@ def init_tree(leaf_node_value, idx_data_points, m):
         Returns
         -------

         """
-        new_tree = Tree(tree_id, len(idx_data_points), m)
+        new_tree = Tree(len(idx_data_points), m)
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
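Note (annotation, not part of the patch): none of the changes above touch the public API; PGBART is still assigned automatically to BART variables via `competence`, so existing models run unchanged. A minimal usage sketch on synthetic data (variable names are hypothetical; `pm.BART` on this branch takes `X`, `Y`, and the number of trees `m`):

```python
import numpy as np
import pymc as pm

rng = np.random.default_rng(42)
X = np.linspace(0, 10, 100)[:, None]
Y = np.sin(X[:, 0]) + rng.normal(0, 0.2, size=100)

with pm.Model():
    mu = pm.BART("mu", X, Y, m=50)  # PGBART is selected for `mu` automatically
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
    # variable_inclusion and bart_trees are recorded as sampler stats
    idata = pm.sample()
```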