diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 06375fca..b47e912f 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -9,13 +9,6 @@ methods in the current release of PyMC experimental.
    :maxdepth: 2
 
 
-:mod:`pymc_experimental.bart`
-=============================
-
-.. automodule:: pymc_experimental.bart
-   :members: BART, PGBART, plot_dependence, plot_variable_importance, predict
-
-
 :mod:`pymc_experimental.distributions`
 =============================
 
diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py
index 8946a3a1..ecb0d2f5 100644
--- a/pymc_experimental/__init__.py
+++ b/pymc_experimental/__init__.py
@@ -12,4 +12,3 @@
 from pymc_experimental import distributions, gp, utils
-from pymc_experimental.bart import *
 
 
diff --git a/pymc_experimental/bart/__init__.py b/pymc_experimental/bart/__init__.py
deleted file mode 100644
index 4b9379fe..00000000
--- a/pymc_experimental/bart/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from pymc_experimental.bart.bart import BART
-from pymc_experimental.bart.pgbart import PGBART
-from pymc_experimental.bart.utils import (
-    plot_dependence,
-    plot_variable_importance,
-    predict,
-)
-
-__all__ = ["BART", "PGBART"]
-
-
-import pymc as pm
-
-pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
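Note that deleting `pymc_experimental/bart/__init__.py` also removes the import-time side effect that appended `PGBART` to `pm.STEP_METHODS`, which is what let `pm.sample` assign the sampler to `BART` variables automatically. A rough sketch of that selection mechanism (simplified and hedged; `pick_step` is a hypothetical helper, not PyMC API, and the real logic lives inside `pymc.sample`):

    import pymc as pm

    def pick_step(var, has_grad=False):
        # pm.sample ranks the registered step classes by their competence()
        # for each free variable; Competence is an IntEnum, so max() applies.
        return max(pm.STEP_METHODS, key=lambda step: step.competence(var, has_grad))
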
diff --git a/pymc_experimental/bart/bart.py b/pymc_experimental/bart/bart.py
deleted file mode 100644
index db7dcbbf..00000000
--- a/pymc_experimental/bart/bart.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import aesara.tensor as at
-import numpy as np
-from aeppl.logprob import _logprob
-from aesara.tensor.random.op import RandomVariable
-from pandas import DataFrame, Series
-from pymc.distributions.distribution import NoDistribution, _moment
-
-__all__ = ["BART"]
-
-
-class BARTRV(RandomVariable):
-    """Base class for BART."""
-
-    name = "BART"
-    ndim_supp = 1
-    ndims_params = [2, 1, 0, 0, 1]
-    dtype = "floatX"
-    _print_name = ("BART", "\\operatorname{BART}")
-    all_trees = None
-
-    def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
-        return (self.X.shape[0],)
-
-    @classmethod
-    def rng_fn(cls, rng, X, Y, m, alpha, split_prior, size):
-        if size is not None:
-            return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
-        else:
-            return np.full(cls.Y.shape[0], cls.Y.mean())
-
-
-bart = BARTRV()
-
-
-class BART(NoDistribution):
-    """
-    Bayesian Additive Regression Tree distribution.
-
-    Distribution representing a sum over trees.
-
-    Parameters
-    ----------
-    X : array-like
-        The covariate matrix.
-    Y : array-like
-        The response vector.
-    m : int
-        Number of trees.
-    alpha : float
-        Controls the prior probability over the depth of the trees. Although it can take
-        values in the interval (0, 1), values in the interval (0, 0.5] are recommended.
-    split_prior : array-like
-        Each element of split_prior should be in the [0, 1] interval and the elements
-        should sum to 1. Otherwise they will be normalized.
-        Defaults to None, i.e. all covariates have the same prior probability of being
-        selected.
-    """
-
-    def __new__(
-        cls,
-        name,
-        X,
-        Y,
-        m=50,
-        alpha=0.25,
-        split_prior=None,
-        **kwargs,
-    ):
-
-        X, Y = preprocess_XY(X, Y)
-
-        if split_prior is None:
-            split_prior = np.ones(X.shape[1])
-
-        bart_op = type(
-            f"BART_{name}",
-            (BARTRV,),
-            dict(
-                name="BART",
-                inplace=False,
-                initval=Y.mean(),
-                X=X,
-                Y=Y,
-                m=m,
-                alpha=alpha,
-                split_prior=split_prior,
-            ),
-        )()
-
-        NoDistribution.register(BARTRV)
-
-        @_moment.register(BARTRV)
-        def get_moment(rv, size, *rv_inputs):
-            return cls.get_moment(rv, size, *rv_inputs)
-
-        cls.rv_op = bart_op
-        params = [X, Y, m, alpha, split_prior]
-        return super().__new__(cls, name, *params, **kwargs)
-
-    @classmethod
-    def dist(cls, *params, **kwargs):
-        return super().dist(params, **kwargs)
-
-    def logp(x, *inputs):
-        """Calculate log probability.
-
-        Parameters
-        ----------
-        x : numeric, TensorVariable
-            Value for which log-probability is calculated.
-
-        Returns
-        -------
-        TensorVariable
-        """
-        return at.zeros_like(x)
-
-    @classmethod
-    def get_moment(cls, rv, size, *rv_inputs):
-        mean = at.fill(size, rv.Y.mean())
-        return mean
-
-
-def preprocess_XY(X, Y):
-    if isinstance(Y, (Series, DataFrame)):
-        Y = Y.to_numpy()
-    if isinstance(X, (Series, DataFrame)):
-        X = X.to_numpy()
-    Y = Y.astype(float)
-    X = X.astype(float)
-    return X, Y
-
-
-@_logprob.register(BARTRV)
-def logp(op, value_var, *dist_params, **kwargs):
-    _dist_params = dist_params[3:]
-    value_var = value_var[0]
-    return BART.logp(value_var, *_dist_params)
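For orientation, this is how the distribution deleted above was used; the snippet follows the `test_bart_vi` test removed at the bottom of this diff:

    import numpy as np
    import pymc as pm
    import pymc_experimental as pmx  # pre-removal API

    X = np.random.normal(0, 1, size=(250, 3))
    Y = np.random.normal(0, 1, size=250)
    X[:, 0] = np.random.normal(Y, 0.1)  # make the first covariate informative

    with pm.Model():
        mu = pmx.BART("mu", X, Y, m=10)      # sum-of-trees prior over f(X)
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu, sigma, observed=Y)
        idata = pm.sample(random_seed=3415)  # PGBART is auto-assigned to mu
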
diff --git a/pymc_experimental/bart/pgbart.py b/pymc_experimental/bart/pgbart.py
deleted file mode 100644
index 4f98d982..00000000
--- a/pymc_experimental/bart/pgbart.py
+++ /dev/null
@@ -1,594 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from copy import copy
-
-import aesara
-import numpy as np
-from aesara import function as aesara_function
-from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
-from pymc.model import modelcontext
-from pymc.step_methods.arraystep import ArrayStepShared, Competence
-
-from pymc_experimental.bart.bart import BARTRV
-from pymc_experimental.bart.tree import LeafNode, SplitNode, Tree
-
-_log = logging.getLogger("pymc")
-
-
-class PGBART(ArrayStepShared):
-    """
-    Particle Gibbs BART sampling step.
-
-    Parameters
-    ----------
-    vars : list
-        List of value variables for the sampler.
-    num_particles : int
-        Number of particles for the conditional SMC sampler. Defaults to 40.
-    batch : int or tuple
-        Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees,
-        both during and after tuning. If a tuple is passed, the first element is the batch
-        size during tuning and the second the batch size after tuning.
-    model : PyMC Model
-        Optional model for the sampling step. Defaults to None (taken from context).
-    """
-
-    name = "pgbart"
-    default_blocked = False
-    generates_stats = True
-    stats_dtypes = [{"variable_inclusion": object, "bart_trees": object}]
-
-    def __init__(
-        self,
-        vars=None,
-        num_particles=40,
-        batch="auto",
-        model=None,
-    ):
-        model = modelcontext(model)
-        initial_values = model.initial_point()
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = inputvars(vars)
-        value_bart = vars[0]
-        self.bart = model.values_to_rvs[value_bart].owner.op
-
-        self.X = self.bart.X
-        self.Y = self.bart.Y
-        self.missing_data = np.any(np.isnan(self.X))
-        self.m = self.bart.m
-        self.alpha = self.bart.alpha
-        shape = initial_values[value_bart.name].shape
-        if len(shape) == 1:
-            self.shape = 1
-        else:
-            self.shape = shape[0]
-
-        self.alpha_vec = self.bart.split_prior
-        self.init_mean = self.Y.mean()
-        # if data is binary
-        Y_unique = np.unique(self.Y)
-        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
-            mu_std = 3 / self.m**0.5
-        # maybe we need to check for count data
-        else:
-            mu_std = self.Y.std() / self.m**0.5
-
-        self.num_observations = self.X.shape[0]
-        self.num_variates = self.X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
-
-        self.sum_trees = np.full((self.shape, self.Y.shape[0]), self.init_mean).astype(
-            aesara.config.floatX
-        )
-
-        self.a_tree = Tree.init_tree(
-            leaf_node_value=self.init_mean / self.m,
-            idx_data_points=np.arange(self.num_observations, dtype="int32"),
-            shape=self.shape,
-        )
-        self.mean = fast_mean()
-
-        self.normal = NormalSampler(mu_std, self.shape)
-        self.uniform = UniformSampler(0.33, 0.75, self.shape)
-        self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.alpha_vec)
-
-        self.tune = True
-
-        if batch == "auto":
-            batch = max(1, int(self.m * 0.1))
-            self.batch = (batch, batch)
-        else:
-            if isinstance(batch, (tuple, list)):
-                self.batch = batch
-            else:
-                self.batch = (batch, batch)
-
-        self.log_num_particles = np.log(num_particles)
-        self.indices = list(range(2, num_particles))
-        self.len_indices = len(self.indices)
-
-        shared = make_shared_replacements(initial_values, vars, model)
-        self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
-        self.all_particles = []
-        for _ in range(self.m):
-            self.a_tree.leaf_node_value = self.init_mean / self.m
-            p = ParticleTree(self.a_tree)
-            self.all_particles.append(p)
-        self.all_trees = np.array([p.tree for p in self.all_particles])
-        super().__init__(vars, shared)
-
-    def astep(self, _):
-        variable_inclusion = np.zeros(self.num_variates, dtype="int")
-
-        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
-        for tree_id in tree_ids:
-            # Compute the sum of trees without the old tree that we are attempting to replace
-            self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
-            # Generate an initial set of SMC particles;
-            # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree_id)
-
-            while True:
-                # Sample each particle (try to grow each tree), except for the first two
-                stop_growing = True
-                for p in particles[2:]:
-                    tree_grew = p.sample_tree(
-                        self.ssv,
-                        self.available_predictors,
-                        self.prior_prob_leaf_node,
-                        self.X,
-                        self.missing_data,
-                        self.sum_trees,
-                        self.mean,
-                        self.m,
-                        self.normal,
-                        self.shape,
-                    )
-                    if tree_grew:
-                        self.update_weight(p)
-                    if p.expansion_nodes:
-                        stop_growing = False
-                if stop_growing:
-                    break
-
-                # Normalize weights
-                w_t, normalized_weights = self.normalize(particles[2:])
-
-                # Resample all but the first two particles
-                new_indices = np.random.choice(
-                    self.indices, size=self.len_indices, p=normalized_weights
-                )
-                particles[2:] = particles[new_indices]
-
-                # Set the new weight
-                for p in particles[2:]:
-                    p.log_weight = w_t
-
-            for p in particles[2:]:
-                p.log_weight = p.old_likelihood_logp
-
-            _, normalized_weights = self.normalize(particles)
-            # Get the new tree and update
-            new_particle = np.random.choice(particles, p=normalized_weights)
-            new_tree = new_particle.tree
-
-            new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
-            self.all_particles[tree_id] = new_particle
-            self.sum_trees = self.sum_trees_noi + new_tree._predict()
-            self.all_trees[tree_id] = new_tree.trim()
-
-            if self.tune:
-                self.ssv = SampleSplittingVariable(self.alpha_vec)
-                for index in new_particle.used_variates:
-                    self.alpha_vec[index] += 1
-            else:
-                for index in new_particle.used_variates:
-                    variable_inclusion[index] += 1
-
-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
-        return self.sum_trees, [stats]
-
-    def normalize(self, particles):
-        """Use the logsumexp trick to get w_t and a softmax to get normalized_weights.
-
-        w_t is the un-normalized weight per particle; we assign it to the next round of
-        particles so they all start with the same weight.
-        """
-        log_w = np.array([p.log_weight for p in particles])
-        log_w_max = log_w.max()
-        log_w_ = log_w - log_w_max
-        w_ = np.exp(log_w_)
-        w_sum = w_.sum()
-        w_t = log_w_max + np.log(w_sum) - self.log_num_particles
-        normalized_weights = w_ / w_sum
-        # stabilize weights to avoid assigning exactly zero probability to a particle
-        normalized_weights += 1e-12
-
-        return w_t, normalized_weights
-
-    def init_particles(self, tree_id: int) -> np.ndarray:
-        """Initialize particles."""
-        p0 = self.all_particles[tree_id]
-        p1 = copy(p0)
-        p1.sample_leafs(
-            self.sum_trees,
-            self.mean,
-            self.m,
-            self.normal,
-            self.shape,
-        )
-
-        # The old tree and the one with resampled leaf values do not grow,
-        # so we update their weights only once
-        self.update_weight(p0, old=True)
-        self.update_weight(p1, old=True)
-        particles = [p0, p1]
-
-        for _ in self.indices:
-            pt = ParticleTree(self.a_tree)
-            if self.tune:
-                pt.kf = self.uniform.random()
-            else:
-                pt.kf = p0.kf
-            particles.append(pt)
-
-        return np.array(particles)
-
-    def update_weight(self, particle, old=False):
-        """
-        Update the weight of a particle.
-
-        Since the prior is used as the proposal, the weights are updated additively with the
-        difference between the new and the old log-likelihood.
-        """
-        new_likelihood = self.likelihood_logp(
-            (self.sum_trees_noi + particle.tree._predict()).flatten()
-        )
-        if old:
-            particle.log_weight = new_likelihood
-            particle.old_likelihood_logp = new_likelihood
-        else:
-            particle.log_weight += new_likelihood - particle.old_likelihood_logp
-            particle.old_likelihood_logp = new_likelihood
-
-    @staticmethod
-    def competence(var, has_grad):
-        """PGBART is only suitable for BART distributions."""
-        dist = getattr(var.owner, "op", None)
-        if isinstance(dist, BARTRV):
-            return Competence.IDEAL
-        return Competence.INCOMPATIBLE
-
-
-class ParticleTree:
-    """Particle tree."""
-
-    def __init__(self, tree):
-        self.tree = tree.copy()  # the tree we are currently working on
-        self.expansion_nodes = [0]
-        self.log_weight = 0
-        self.old_likelihood_logp = 0
-        self.used_variates = []
-        self.kf = 0.75
-
-    def sample_tree(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees,
-        mean,
-        m,
-        normal,
-        shape,
-    ):
-        tree_grew = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees,
-                    mean,
-                    m,
-                    normal,
-                    self.kf,
-                    shape,
-                )
-                if index_selected_predictor is not None:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-                    tree_grew = True
-
-        return tree_grew
-
-    def sample_leafs(self, sum_trees, mean, m, normal, shape):
-
-        for idx in self.tree.idx_leaf_nodes:
-            if idx > 0:
-                leaf = self.tree[idx]
-                idx_data_points = leaf.idx_data_points
-                node_value = draw_leaf_value(
-                    sum_trees[:, idx_data_points],
-                    mean,
-                    m,
-                    normal,
-                    self.kf,
-                    shape,
-                )
-                leaf.value = node_value
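The `normalize` method above is a standard logsumexp computation; here is the same arithmetic in isolation (a minimal sketch, `normalize_log_weights` is an illustrative name, not part of the removed module):

    import numpy as np

    def normalize_log_weights(log_w, log_num_particles):
        # Subtract the max before exponentiating so np.exp cannot overflow.
        log_w_max = log_w.max()
        w_ = np.exp(log_w - log_w_max)
        w_sum = w_.sum()
        w_t = log_w_max + np.log(w_sum) - log_num_particles  # shared next-round log-weight
        return w_t, w_ / w_sum

    w_t, probs = normalize_log_weights(np.array([-3.0, -1.0, -2.5]), np.log(3))
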
-
-
-class SampleSplittingVariable:
-    def __init__(self, alpha_vec):
-        """
-        Sample splitting variables proportionally to `alpha_vec`.
-
-        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model,
-        and it encourages sparsity.
-        """
-        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
-
-    def rvs(self):
-        r = np.random.random()
-        for i, v in self.enu:
-            if r <= v:
-                return i
-
-
-def compute_prior_probability(alpha):
-    """
-    Calculate the probability of a node being a LeafNode (1 - p(being SplitNode)).
-
-    Taken from equation 19 in [Rockova2018].
-
-    Parameters
-    ----------
-    alpha : float
-
-    Returns
-    -------
-    list with probabilities for leaf nodes
-
-    References
-    ----------
-    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
-       arXiv, `link `__
-    """
-    prior_leaf_prob = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha**depth)
-        depth += 1
-    return prior_leaf_prob
-
-
-def grow_tree(
-    tree,
-    index_leaf_node,
-    ssv,
-    available_predictors,
-    X,
-    missing_data,
-    sum_trees,
-    mean,
-    m,
-    normal,
-    kf,
-    shape,
-):
-    current_node = tree.get_node(index_leaf_node)
-    idx_data_points = current_node.idx_data_points
-
-    index_selected_predictor = ssv.rvs()
-    selected_predictor = available_predictors[index_selected_predictor]
-    available_splitting_values = X[idx_data_points, selected_predictor]
-    split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
-
-    if split_value is not None:
-
-        new_idx_data_points = get_new_idx_data_points(
-            split_value, idx_data_points, selected_predictor, X
-        )
-        current_node_children = (
-            current_node.get_idx_left_child(),
-            current_node.get_idx_right_child(),
-        )
-
-        new_nodes = []
-        for idx in range(2):
-            idx_data_point = new_idx_data_points[idx]
-            node_value = draw_leaf_value(
-                sum_trees[:, idx_data_point],
-                mean,
-                m,
-                normal,
-                kf,
-                shape,
-            )
-
-            new_node = LeafNode(
-                index=current_node_children[idx],
-                value=node_value,
-                idx_data_points=idx_data_point,
-            )
-            new_nodes.append(new_node)
-
-        new_split_node = SplitNode(
-            index=index_leaf_node,
-            idx_split_variable=selected_predictor,
-            split_value=split_value,
-        )
-
-        # update tree nodes and indexes
-        tree.delete_leaf_node(index_leaf_node)
-        tree.set_node(index_leaf_node, new_split_node)
-        tree.set_node(new_nodes[0].index, new_nodes[0])
-        tree.set_node(new_nodes[1].index, new_nodes[1])
-
-        return index_selected_predictor
-
-
-def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
-
-    left_idx = X[idx_data_points, selected_predictor] <= split_value
-    left_node_idx_data_points = idx_data_points[left_idx]
-    right_node_idx_data_points = idx_data_points[~left_idx]
-
-    return left_node_idx_data_points, right_node_idx_data_points
-
-
-def get_split_value(available_splitting_values, idx_data_points, missing_data):
-
-    if missing_data:
-        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
-        available_splitting_values = available_splitting_values[
-            ~np.isnan(available_splitting_values)
-        ]
-
-    split_value = None  # returned when no splitting values remain (e.g. all NaN)
-    if available_splitting_values.size > 0:
-        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
-        split_value = available_splitting_values[idx_selected_splitting_values]
-
-    return split_value
-
-
-def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
-    """Draw Gaussian distributed leaf values."""
-    if Y_mu_pred.size == 0:
-        return np.zeros(shape)
-    else:
-        norm = normal.random() * kf
-        if Y_mu_pred.size == 1:
-            mu_mean = np.full(shape, Y_mu_pred.item() / m)
-        else:
-            mu_mean = mean(Y_mu_pred) / m
-
-        draw = norm + mu_mean
-        return draw
-
-
-def fast_mean():
-    """If available, use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        from functools import partial
-
-        return partial(np.mean, axis=1)
-
-    @jit
-    def mean(a):
-        if a.ndim == 1:
-            count = a.shape[0]
-            suma = 0
-            for i in range(count):
-                suma += a[i]
-            return suma / count
-        elif a.ndim == 2:
-            res = np.zeros(a.shape[0])
-            count = a.shape[1]
-            for j in range(a.shape[0]):
-                for i in range(count):
-                    res[j] += a[j, i]
-            return res / count
-
-    return mean
-
-
-def discrete_uniform_sampler(upper_value):
-    """Draw an integer from the uniform distribution with bounds [0, upper_value).
-
-    This is the same as np.random.randint(upper_value) but faster.
-    """
-    return int(np.random.random() * upper_value)
-
-
-class NormalSampler:
-    """Cache samples from a normal distribution."""
-
-    def __init__(self, scale, shape):
-        self.size = 1000
-        self.scale = scale
-        self.shape = shape
-        self.update()
-
-    def random(self):
-        if self.idx == self.size:
-            self.update()
-        pop = self.cache[:, self.idx]
-        self.idx += 1
-        return pop
-
-    def update(self):
-        self.idx = 0
-        self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
-
-
-class UniformSampler:
-    """Cache samples from a uniform distribution."""
-
-    def __init__(self, lower_bound, upper_bound, shape):
-        self.size = 1000
-        self.upper_bound = upper_bound
-        self.lower_bound = lower_bound
-        self.shape = shape
-        self.update()
-
-    def random(self):
-        if self.idx == self.size:
-            self.update()
-        pop = self.cache[:, self.idx]
-        self.idx += 1
-        return pop
-
-    def update(self):
-        self.idx = 0
-        self.cache = np.random.uniform(
-            self.lower_bound, self.upper_bound, size=(self.shape, self.size)
-        )
-
-
-def logp(point, out_vars, vars, shared):
-    """Compile an Aesara function of the model, given input and output variables.
-
-    Parameters
-    ----------
-    out_vars : List
-        Containing :class:`pymc.Distribution` for the output variables.
-    vars : List
-        Containing :class:`pymc.Distribution` for the input variables.
-    shared : List
-        Containing :class:`aesara.tensor.Tensor` for dependent shared data.
-    """
-    out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
-    f = aesara_function([inarray0], out_list[0])
-    f.trust_input = True
-    return f
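For intuition about `compute_prior_probability` above: with the default alpha = 0.25, the prior probability that a node at depth d stays a leaf is 1 - 0.25**d, which grows quickly toward 1 (depth 0, the root, always gets a chance to split):

    alpha = 0.25
    prior_leaf_prob = [0]  # depth 0: the root is never forced to stay a leaf
    for depth in range(1, 5):
        prior_leaf_prob.append(1 - alpha**depth)
    # -> [0, 0.75, 0.9375, 0.984375, 0.99609375]
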
diff --git a/pymc_experimental/bart/tree.py b/pymc_experimental/bart/tree.py
deleted file mode 100644
index d27b2d39..00000000
--- a/pymc_experimental/bart/tree.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from copy import deepcopy
-
-import aesara
-import numpy as np
-
-
-class Tree:
-    """Full binary tree.
-
-    A full binary tree is a tree where each node has exactly zero or two children.
-    This structure is used as the basic component of the Bayesian Additive Regression
-    Tree (BART).
-
-    Attributes
-    ----------
-    tree_structure : dict
-        A dictionary that represents the nodes stored in breadth-first order, based on the
-        array method for storing binary trees
-        (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
-        The dictionary's keys are integers that represent the nodes' positions.
-        The dictionary's values are objects of type SplitNode or LeafNode that represent
-        the nodes of the tree itself.
-    idx_leaf_nodes : list
-        List with the indices of the leaf nodes of the tree.
-    num_observations : int
-        Number of observations used to fit BART.
-    m : int
-        Number of trees.
-
-    Parameters
-    ----------
-    num_observations : int, optional
-    """
-
-    def __init__(self, num_observations=0, shape=1):
-        self.tree_structure = {}
-        self.idx_leaf_nodes = []
-        self.output = np.zeros((num_observations, shape)).astype(aesara.config.floatX).squeeze()
-
-    def __getitem__(self, index):
-        return self.get_node(index)
-
-    def __setitem__(self, index, node):
-        self.set_node(index, node)
-
-    def copy(self):
-        return deepcopy(self)
-
-    def get_node(self, index):
-        return self.tree_structure[index]
-
-    def set_node(self, index, node):
-        self.tree_structure[index] = node
-        if isinstance(node, LeafNode):
-            self.idx_leaf_nodes.append(index)
-
-    def delete_leaf_node(self, index):
-        self.idx_leaf_nodes.remove(index)
-        del self.tree_structure[index]
-
-    def trim(self):
-        a_tree = self.copy()
-        del a_tree.output
-        del a_tree.idx_leaf_nodes
-        for k in a_tree.tree_structure.keys():
-            current_node = a_tree[k]
-            del current_node.depth
-            if isinstance(current_node, LeafNode):
-                del current_node.idx_data_points
-        return a_tree
-
-    def _predict(self):
-        output = self.output
-        for node_index in self.idx_leaf_nodes:
-            leaf_node = self.get_node(node_index)
-            output[leaf_node.idx_data_points] = leaf_node.value
-        return output.T
-
-    def predict(self, x, excluded=None):
-        """
-        Predict the output of the tree for an (un)observed point x.
-
-        Parameters
-        ----------
-        x : numpy array
-            Unobserved point.
-
-        Returns
-        -------
-        float
-            Value of the leaf where the unobserved point lies.
-        """
-        if excluded is None:
-            excluded = []
-        node = self._traverse_tree(x, 0, excluded)
-        if isinstance(node, LeafNode):
-            leaf_value = node.value
-        else:
-            leaf_value = node
-        return leaf_value
-
-    def _traverse_tree(self, x, node_index, excluded):
-        """
-        Traverse the tree starting from a particular node given an unobserved point.
-
-        Parameters
-        ----------
-        x : np.ndarray
-        node_index : int
-
-        Returns
-        -------
-        LeafNode or mean of leaf node values
-        """
-        current_node = self.get_node(node_index)
-        if isinstance(current_node, SplitNode):
-            if current_node.idx_split_variable in excluded:
-                leaf_values = []
-                self._traverse_leaf_values(leaf_values, node_index)
-                return np.mean(leaf_values, 0)
-
-            if x[current_node.idx_split_variable] <= current_node.split_value:
-                left_child = current_node.get_idx_left_child()
-                current_node = self._traverse_tree(x, left_child, excluded)
-            else:
-                right_child = current_node.get_idx_right_child()
-                current_node = self._traverse_tree(x, right_child, excluded)
-        return current_node
-
-    def _traverse_leaf_values(self, leaf_values, node_index):
-        """
-        Traverse the tree appending leaf values, starting from a particular node.
-
-        Parameters
-        ----------
-        node_index : int
-
-        Returns
-        -------
-        List of leaf node values
-        """
-        current_node = self.get_node(node_index)
-        if isinstance(current_node, SplitNode):
-            left_child = current_node.get_idx_left_child()
-            self._traverse_leaf_values(leaf_values, left_child)
-            right_child = current_node.get_idx_right_child()
-            self._traverse_leaf_values(leaf_values, right_child)
-        else:
-            leaf_values.append(current_node.value)
-
-    @staticmethod
-    def init_tree(leaf_node_value, idx_data_points, shape):
-        """
-        Initialize the tree.
-
-        Parameters
-        ----------
-        leaf_node_value
-        idx_data_points
-
-        Returns
-        -------
-        tree
-        """
-        new_tree = Tree(len(idx_data_points), shape)
-        new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
-        return new_tree
-
-
-class BaseNode:
-    def __init__(self, index):
-        self.index = index
-        self.depth = int(math.floor(math.log(index + 1, 2)))
-
-    def get_idx_parent_node(self):
-        return (self.index - 1) // 2
-
-    def get_idx_left_child(self):
-        return self.index * 2 + 1
-
-    def get_idx_right_child(self):
-        return self.get_idx_left_child() + 1
-
-
-class SplitNode(BaseNode):
-    def __init__(self, index, idx_split_variable, split_value):
-        super().__init__(index)
-
-        self.idx_split_variable = idx_split_variable
-        self.split_value = split_value
-
-
-class LeafNode(BaseNode):
-    def __init__(self, index, value, idx_data_points):
-        super().__init__(index)
-        self.value = value
-        self.idx_data_points = idx_data_points
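The node classes above rely on the array layout for binary trees, so parent/child lookups are pure index arithmetic; the relations used by `BaseNode`, checked against the values asserted in the deleted tests below:

    import math

    index = 5
    parent = (index - 1) // 2                         # 2
    left, right = index * 2 + 1, index * 2 + 2        # 11, 12
    depth = int(math.floor(math.log(index + 1, 2)))   # 2
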
diff --git a/pymc_experimental/bart/utils.py b/pymc_experimental/bart/utils.py
deleted file mode 100644
index 8be33ee4..00000000
--- a/pymc_experimental/bart/utils.py
+++ /dev/null
@@ -1,392 +0,0 @@
-"""Utility functions for variable selection and BART interpretability."""
-
-import arviz as az
-import matplotlib.pyplot as plt
-import numpy as np
-from numpy.random import RandomState
-from scipy.interpolate import griddata
-from scipy.signal import savgol_filter
-from scipy.stats import pearsonr
-
-
-def predict(idata, rng, X, size=None, excluded=None):
-    """
-    Generate samples from the BART posterior.
-
-    Parameters
-    ----------
-    idata : InferenceData
-        InferenceData containing a collection of bart_trees in the sample_stats group.
-    rng : NumPy random generator
-    X : array-like
-        A covariate matrix. Use the same matrix used to fit BART for in-sample predictions,
-        or a new one for out-of-sample predictions.
-    size : int or tuple
-        Number of samples.
-    excluded : list
-        Indices of the variables to exclude when computing predictions.
-    """
-    bart_trees = idata.sample_stats.bart_trees
-    stacked_trees = bart_trees.stack(trees=["chain", "draw"])
-    if size is None:
-        size = ()
-    elif isinstance(size, int):
-        size = [size]
-
-    flatten_size = 1
-    for s in size:
-        flatten_size *= s
-
-    idx = rng.randint(len(stacked_trees.trees), size=flatten_size)
-    shape = stacked_trees.isel(trees=0).values[0].predict(X[0]).size
-
-    pred = np.zeros((flatten_size, X.shape[0], shape))
-
-    for ind, p in enumerate(pred):
-        for tree in stacked_trees.isel(trees=idx[ind]).values:
-            p += np.array([tree.predict(x, excluded) for x in X])
-    pred.reshape((*size, shape, -1))
-    return pred
-
-
-def plot_dependence(
-    idata,
-    X,
-    Y=None,
-    kind="pdp",
-    xs_interval="linear",
-    xs_values=None,
-    var_idx=None,
-    var_discrete=None,
-    samples=50,
-    instances=10,
-    random_seed=None,
-    sharey=True,
-    rug=True,
-    smooth=True,
-    indices=None,
-    grid="long",
-    color="C0",
-    color_mean="C0",
-    alpha=0.1,
-    figsize=None,
-    smooth_kwargs=None,
-    ax=None,
-):
-    """
-    Partial dependence or individual conditional expectation plot.
-
-    Parameters
-    ----------
-    idata : InferenceData
-        InferenceData containing a collection of bart_trees in the sample_stats group.
-    X : array-like
-        The covariate matrix.
-    Y : array-like
-        The response vector.
-    kind : str
-        Whether to plot a partial dependence plot ("pdp") or an individual conditional
-        expectation plot ("ice"). Defaults to "pdp".
-    xs_interval : str
-        Method used to compute the values of X at which the predicted function is evaluated.
-        "linear", evenly spaced values in the range of X. "quantiles", the evaluation is done
-        at the specified quantiles of X. "insample", the evaluation is done at the values of X.
-        For discrete variables these options are omitted.
-    xs_values : int or list
-        Values of X used to evaluate the predicted function. If ``xs_interval="linear"``,
-        number of points in the evenly spaced grid. If ``xs_interval="quantiles"``, quantile
-        or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
-        Ignored when ``xs_interval="insample"``.
-    var_idx : list
-        List of the indices of the covariates for which to compute the pdp or ice.
-    var_discrete : list
-        List of the indices of the covariates treated as discrete.
-    samples : int
-        Number of posterior samples used in the predictions. Defaults to 50.
-    instances : int
-        Number of instances of X to plot. Only relevant for ``kind="ice"`` plots.
-    random_seed : int
-        Seed used to sample from the posterior. Defaults to None.
-    sharey : bool
-        Controls sharing of properties among y-axes. Defaults to True.
-    rug : bool
-        Whether to include a rugplot. Defaults to True.
-    smooth : bool
-        If True the result will be smoothed by first computing a linear interpolation of the
-        data over a regular grid and then applying the Savitzky-Golay filter to the
-        interpolated data. Defaults to True.
-    grid : str or tuple
-        How to arrange the subplots. Defaults to "long", one subplot below the other.
-        Other options are "wide", one subplot next to each other, or a tuple indicating the
-        number of rows and columns.
-    color : matplotlib valid color
-        Color used to plot the pdp or ice. Defaults to "C0".
-    color_mean : matplotlib valid color
-        Color used to plot the mean pdp or ice. Defaults to "C0".
-    alpha : float
-        Transparency level; should be in the interval [0, 1].
-    figsize : tuple
-        Figure size. If None it will be defined automatically.
-    smooth_kwargs : dict
-        Additional keywords modifying the Savitzky-Golay filter.
-        See scipy.signal.savgol_filter() for details.
-    ax : axes
-        Matplotlib axes.
-
-    Returns
-    -------
-    axes: matplotlib axes
-    """
-    if kind not in ["pdp", "ice"]:
-        raise ValueError(f"kind={kind} is not supported. Available options are 'pdp' or 'ice'")
-
-    if xs_interval not in ["insample", "linear", "quantiles"]:
-        raise ValueError(
-            f"""{xs_interval} is not supported.
-            Available options are 'insample', 'linear' or 'quantiles'"""
-        )
-
-    rng = RandomState(seed=random_seed)
-
-    if hasattr(X, "columns") and hasattr(X, "values"):
-        X_names = list(X.columns)
-        X = X.values
-    else:
-        X_names = []
-
-    if hasattr(Y, "name"):
-        Y_label = f"Predicted {Y.name}"
-    else:
-        Y_label = "Predicted Y"
-
-    num_covariates = X.shape[1]
-
-    indices = list(range(num_covariates))
-
-    if var_idx is None:
-        var_idx = indices
-    if var_discrete is None:
-        var_discrete = []
-
-    if X_names:
-        X_labels = [X_names[idx] for idx in var_idx]
-    else:
-        X_labels = [f"X_{idx}" for idx in var_idx]
-
-    if xs_interval == "linear" and xs_values is None:
-        xs_values = 10
-
-    if xs_interval == "quantiles" and xs_values is None:
-        xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]
-
-    if kind == "ice":
-        instances = np.random.choice(range(X.shape[0]), replace=False, size=instances)
-
-    new_Y = []
-    new_X_target = []
-    y_mins = []
-
-    new_X = np.zeros_like(X)
-    idx_s = list(range(X.shape[0]))
-    for i in var_idx:
-        indices_mi = indices[:]
-        indices_mi.pop(i)
-        y_pred = []
-        if kind == "pdp":
-            if i in var_discrete:
-                new_X_i = np.unique(X[:, i])
-            else:
-                if xs_interval == "linear":
-                    new_X_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values)
-                elif xs_interval == "quantiles":
-                    new_X_i = np.quantile(X[:, i], q=xs_values)
-                elif xs_interval == "insample":
-                    new_X_i = X[:, i]
-
-            for x_i in new_X_i:
-                new_X[:, indices_mi] = X[:, indices_mi]
-                new_X[:, i] = x_i
-                y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 1))
-            new_X_target.append(new_X_i)
-        else:
-            for instance in instances:
-                new_X = X[idx_s]
-                new_X[:, indices_mi] = X[:, indices_mi][instance]
-                y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 0))
-                new_X_target.append(new_X[:, i])
-        y_mins.append(np.min(y_pred))
-        new_Y.append(np.array(y_pred).T)
-
-    shape = 1
-    if new_Y[0].ndim == 3:
-        shape = new_Y[0].shape[0]
-    if ax is None:
-        if grid == "long":
-            fig, axes = plt.subplots(len(var_idx) * shape, sharey=sharey, figsize=figsize)
-        elif grid == "wide":
-            fig, axes = plt.subplots(1, len(var_idx) * shape, sharey=sharey, figsize=figsize)
-        elif isinstance(grid, tuple):
-            fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
-        axes = np.ravel(axes)
-    else:
-        axes = [ax]
-        fig = ax.get_figure()
-
-    x_idx = 0
-    y_idx = 0
-    for ax in axes:
-        if x_idx >= len(var_idx):
-            ax.set_axis_off()
-            fig.delaxes(ax)
-
-        nyi = new_Y[x_idx][y_idx]
-        nxi = new_X_target[x_idx]
-        var = var_idx[x_idx]
-
-        ax.set_xlabel(X_labels[x_idx])
-        x_idx += 1
-        if x_idx == len(var_idx):
-            x_idx = 0
-            y_idx += 1
-
-        if var in var_discrete:
-            if kind == "pdp":
-                y_means = nyi.mean(0)
-                hdi = az.hdi(nyi)
-                ax.errorbar(
-                    nxi,
-                    y_means,
-                    (y_means - hdi[:, 0], hdi[:, 1] - y_means),
-                    fmt=".",
-                    color=color,
-                )
-            else:
-                ax.plot(nxi, nyi, ".", color=color, alpha=alpha)
-                ax.plot(nxi, nyi.mean(1), "o", color=color_mean)
-            ax.set_xticks(nxi)
-        elif smooth:
-            if smooth_kwargs is None:
-                smooth_kwargs = {}
-            smooth_kwargs.setdefault("window_length", 55)
-            smooth_kwargs.setdefault("polyorder", 2)
-            x_data = np.linspace(np.nanmin(nxi), np.nanmax(nxi), 200)
-            x_data[0] = (x_data[0] + x_data[1]) / 2
-            if kind == "pdp":
-                interp = griddata(nxi, nyi.mean(0), x_data)
-            else:
-                interp = griddata(nxi, nyi, x_data)
-
-            y_data = savgol_filter(interp, axis=0, **smooth_kwargs)
-
-            if kind == "pdp":
-                az.plot_hdi(nxi, nyi, color=color, fill_kwargs={"alpha": alpha}, ax=ax)
-                ax.plot(x_data, y_data, color=color_mean)
-            else:
-                ax.plot(x_data, y_data.mean(1), color=color_mean)
-                ax.plot(x_data, y_data, color=color, alpha=alpha)
-
-        else:
-            idx = np.argsort(nxi)
-            if kind == "pdp":
-                az.plot_hdi(
-                    nxi,
-                    nyi,
-                    smooth=smooth,
-                    fill_kwargs={"alpha": alpha},
-                    ax=ax,
-                )
-                ax.plot(nxi[idx], nyi[idx].mean(0), color=color)
-            else:
-                ax.plot(nxi[idx], nyi[idx], color=color, alpha=alpha)
-                ax.plot(nxi[idx], nyi[idx].mean(1), color=color_mean)
-
-        if rug:
-            lb = np.min(y_mins)
-            ax.plot(X[:, var], np.full_like(X[:, var], lb), "k|")
-
-    fig.text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)
-    return axes
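A typical call of the function above, mirroring the parametrized cases in the deleted tests (with `pmx`, `idata`, `X`, and `Y` as in the model sketch shown earlier):

    pmx.bart.utils.plot_dependence(
        idata, X=X, Y=Y,
        kind="pdp", xs_interval="quantiles", xs_values=[0.25, 0.5, 0.75],
    )
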
-
-
-def plot_variable_importance(
-    idata, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
-):
-    """
-    Estimate variable importance from the BART posterior.
-
-    Parameters
-    ----------
-    idata : InferenceData
-        InferenceData containing a collection of bart_trees in the sample_stats group.
-    X : array-like
-        The covariate matrix.
-    labels : list
-        List of the names of the covariates. If X is a DataFrame, the names of the covariates
-        will be taken from it and this argument will be ignored.
-    sort_vars : bool
-        Whether to sort the variables according to their variable importance. Defaults to True.
-    figsize : tuple
-        Figure size. If None it will be defined automatically.
-    samples : int
-        Number of predictions used to compute the correlation for subsets of variables.
-        Defaults to 100.
-    random_seed : int
-        Seed used to sample from the posterior. Defaults to None.
-
-    Returns
-    -------
-    idxs : indices of the covariates from higher to lower relative importance
-    axes : matplotlib axes
-    """
-    rng = RandomState(seed=random_seed)
-    _, axes = plt.subplots(2, 1, figsize=figsize)
-
-    if hasattr(X, "columns") and hasattr(X, "values"):
-        labels = X.columns
-        X = X.values
-
-    VI = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values
-    if labels is None:
-        labels = np.arange(len(VI))
-    else:
-        labels = np.array(labels)
-
-    ticks = np.arange(len(VI), dtype=int)
-    idxs = np.argsort(VI)
-    subsets = [idxs[:-i] for i in range(1, len(idxs))]
-    subsets.append(None)
-
-    if sort_vars:
-        indices = idxs[::-1]
-    else:
-        indices = np.arange(len(VI))
-    axes[0].plot((VI / VI.sum())[indices], "o-")
-    axes[0].set_xticks(ticks)
-    axes[0].set_xticklabels(labels[indices])
-    axes[0].set_xlabel("covariates")
-    axes[0].set_ylabel("importance")
-
-    predicted_all = predict(idata, rng, X=X, size=samples, excluded=None)
-
-    EV_mean = np.zeros(len(VI))
-    EV_hdi = np.zeros((len(VI), 2))
-    for idx, subset in enumerate(subsets):
-        predicted_subset = predict(idata, rng, X=X, size=samples, excluded=subset)
-        pearson = np.zeros(samples)
-        for j in range(samples):
-            pearson[j] = (
-                pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
-            ) ** 2
-        EV_mean[idx] = np.mean(pearson)
-        EV_hdi[idx] = az.hdi(pearson)
-
-    axes[1].errorbar(ticks, EV_mean, np.array((EV_mean - EV_hdi[:, 0], EV_hdi[:, 1] - EV_mean)))
-
-    axes[1].set_xticks(ticks)
-    axes[1].set_xticklabels(ticks + 1)
-    axes[1].set_xlabel("number of covariates")
-    axes[1].set_ylabel("R²", rotation=0, labelpad=12)
-    axes[1].set_ylim(0, 1)
-
-    axes[0].set_xlim(-0.5, len(VI) - 0.5)
-    axes[1].set_xlim(-0.5, len(VI) - 0.5)
-
-    return idxs[::-1], axes
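The second panel produced above reduces to the squared Pearson correlation between full-model predictions and subset predictions, one value per posterior sample; the metric in isolation (a sketch, `subset_r2` is an illustrative name, not part of the removed module):

    import numpy as np
    from scipy.stats import pearsonr

    def subset_r2(predicted_all, predicted_subset):
        # Both arrays have shape (samples, ...); returns one R^2 per sample.
        return np.array(
            [pearsonr(a.flatten(), b.flatten())[0] ** 2
             for a, b in zip(predicted_all, predicted_subset)]
        )
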
diff --git a/pymc_experimental/tests/test_bart.py b/pymc_experimental/tests/test_bart.py
deleted file mode 100644
index 3148b071..00000000
--- a/pymc_experimental/tests/test_bart.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import numpy as np
-import pymc as pm
-import pytest
-from numpy.random import RandomState
-from numpy.testing import assert_almost_equal, assert_array_equal
-from pymc.tests.distributions.util import assert_moment_is_expected
-
-import pymc_experimental as pmx
-
-
-def test_split_node():
-    split_node = pmx.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
-    assert split_node.index == 5
-    assert split_node.idx_split_variable == 2
-    assert split_node.split_value == 3.0
-    assert split_node.depth == 2
-    assert split_node.get_idx_parent_node() == 2
-    assert split_node.get_idx_left_child() == 11
-    assert split_node.get_idx_right_child() == 12
-
-
-def test_leaf_node():
-    leaf_node = pmx.bart.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
-    assert leaf_node.index == 5
-    assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
-    assert leaf_node.value == 3.14
-    assert leaf_node.get_idx_parent_node() == 2
-    assert leaf_node.get_idx_left_child() == 11
-    assert leaf_node.get_idx_right_child() == 12
-
-
-def test_bart_vi():
-    X = np.random.normal(0, 1, size=(250, 3))
-    Y = np.random.normal(0, 1, size=250)
-    X[:, 0] = np.random.normal(Y, 0.1)
-
-    with pm.Model() as model:
-        mu = pmx.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
-        var_imp = (
-            idata.sample_stats["variable_inclusion"]
-            .stack(samples=("chain", "draw"))
-            .mean("samples")
-        )
-        var_imp /= var_imp.sum()
-        assert var_imp[0] > var_imp[1:].sum()
-        assert_almost_equal(var_imp.sum(), 1)
-
-
-def test_missing_data():
-    X = np.random.normal(0, 1, size=(50, 2))
-    Y = np.random.normal(0, 1, size=50)
-    X[10:20, 0] = np.nan
-
-    with pm.Model() as model:
-        mu = pmx.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(tune=10, draws=10, chains=1, random_seed=3415)
-
-
-class TestUtils:
-    X_norm = np.random.normal(0, 1, size=(50, 2))
-    X_binom = np.random.binomial(1, 0.5, size=(50, 1))
-    X = np.hstack([X_norm, X_binom])
-    Y = np.random.normal(0, 1, size=50)
-
-    with pm.Model() as model:
-        mu = pmx.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
-
-    def test_predict(self):
-        rng = RandomState(12345)
-        pred_all = pmx.bart.utils.predict(self.idata, rng, X=self.X, size=2)
-        rng = RandomState(12345)
-        pred_first = pmx.bart.utils.predict(self.idata, rng, X=self.X[:10])
-
-        assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
-        assert pred_all.shape == (2, 50, 1)
-        assert pred_first.shape == (1, 10, 1)
-
-    @pytest.mark.parametrize(
-        "kwargs",
-        [
-            {},
-            {
-                "kind": "pdp",
-                "samples": 2,
-                "xs_interval": "quantiles",
-                "xs_values": [0.25, 0.5, 0.75],
-                "var_discrete": [3],
-            },
-            {"kind": "ice", "instances": 2},
-            {"var_idx": [0], "rug": False, "smooth": False, "color": "k"},
-            {"grid": (1, 2), "sharey": "none", "alpha": 1},
-        ],
-    )
-    def test_pdp(self, kwargs):
-        pmx.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs)
-
-    @pytest.mark.parametrize(
-        "kwargs",
-        [
-            {},
-            {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)},
-        ],
-    )
-    def test_vi(self, kwargs):
-        pmx.bart.utils.plot_variable_importance(self.idata, X=self.X, **kwargs)
-
-    def test_pdp_pandas_labels(self):
-        pd = pytest.importorskip("pandas")
-
-        X_names = ["norm1", "norm2", "binom"]
-        X_pd = pd.DataFrame(self.X, columns=X_names)
-        Y_pd = pd.Series(self.Y, name="response")
-        axes = pmx.bart.utils.plot_dependence(self.idata, X=X_pd, Y=Y_pd)
-
-        figure = axes[0].figure
-        assert figure.texts[0].get_text() == "Predicted response"
-        assert_array_equal([ax.get_xlabel() for ax in axes], X_names)
-
-
-@pytest.mark.parametrize(
-    "size, expected",
-    [
-        (None, np.zeros(50)),
-    ],
-)
-def test_bart_moment(size, expected):
-    X = np.zeros((50, 2))
-    Y = np.zeros(50)
-    with pm.Model() as model:
-        pmx.BART("x", X=X, Y=Y, size=size)
-    assert_moment_is_expected(model, expected)