Skip to content

Commit 47e2e2b

Browse files
authored
BART-VI use mean of leaf nodes when pruning instead of zero (#54)
* use mean when pruning instead of zero * fix labels * use r2, improve labels
1 parent 415f896 commit 47e2e2b

File tree

2 files changed

+64
-28
lines changed

2 files changed

+64
-28
lines changed

pymc_experimental/bart/tree.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ class Tree:
4949
def __init__(self, num_observations=0, shape=1):
    """Create an empty tree with a zero-initialized output buffer.

    Parameters
    ----------
    num_observations : int
        Number of data points the tree will predict for.
    shape : int
        Dimensionality of each prediction; trailing singleton
        dimensions are squeezed away.
    """
    self.tree_structure = {}
    self.idx_leaf_nodes = []
    # Buffer holding the tree's per-observation output, cast to the
    # aesara float type so it interoperates with the rest of the model.
    buffer = np.zeros((num_observations, shape)).astype(aesara.config.floatX)
    self.output = buffer.squeeze()
5653

5754
def __getitem__(self, index):
    """Support ``tree[index]`` lookup; delegates to ``get_node``."""
    return self.get_node(index)
@@ -93,29 +90,30 @@ def _predict(self):
9390
output[leaf_node.idx_data_points] = leaf_node.value
9491
return output.T
9592

96-
def predict(self, x, excluded=None):
    """
    Predict the output of the tree for an (un)observed point x.

    Parameters
    ----------
    x : numpy array
        Unobserved point.
    excluded : list, optional
        Indices of split variables to exclude. When the traversal reaches
        a split on an excluded variable, the mean of that subtree's leaf
        values is returned instead of a single leaf value.

    Returns
    -------
    float or numpy array
        Value of the leaf where the point lies, or the mean of the leaf
        values of the subtree rooted at an excluded split variable.
    """
    if excluded is None:
        excluded = []
    node = self._traverse_tree(x, 0, excluded)
    # _traverse_tree returns either a LeafNode or an already-computed
    # mean of leaf values (when an excluded split variable was hit).
    return node.value if isinstance(node, LeafNode) else node
117115

118-
def _traverse_tree(self, x, node_index, excluded):
    """
    Traverse the tree starting from a particular node given an unobserved point.

    Parameters
    ----------
    x : numpy array
        Unobserved point.
    node_index : int
        Index of the node to start the traversal from.
    excluded : list
        Indices of split variables to exclude from the traversal.

    Returns
    -------
    LeafNode or mean of leaf node values
    """
    node = self.get_node(node_index)
    # Leaf reached: nothing left to traverse.
    if not isinstance(node, SplitNode):
        return node
    # Split on an excluded variable: summarize the whole subtree by the
    # mean of its leaf values instead of descending further.
    if node.idx_split_variable in excluded:
        collected = []
        self._traverse_leaf_values(collected, node_index)
        return np.mean(collected, 0)
    # Otherwise descend into the child selected by the split rule.
    if x[node.idx_split_variable] <= node.split_value:
        child_index = node.get_idx_left_child()
    else:
        child_index = node.get_idx_right_child()
    return self._traverse_tree(x, child_index, excluded)
140143

144+
def _traverse_leaf_values(self, leaf_values, node_index):
    """
    Collect the leaf values of the subtree rooted at a particular node.

    The values are appended to ``leaf_values`` in place; this method
    returns None.

    Parameters
    ----------
    leaf_values : list
        Accumulator list the leaf node values are appended to.
    node_index : int
        Index of the node to start the traversal from.
    """
    current_node = self.get_node(node_index)
    if isinstance(current_node, SplitNode):
        # Recurse into both children; leaves below them append themselves.
        self._traverse_leaf_values(leaf_values, current_node.get_idx_left_child())
        self._traverse_leaf_values(leaf_values, current_node.get_idx_right_child())
    else:
        leaf_values.append(current_node.value)
164+
141165
@staticmethod
142166
def init_tree(leaf_node_value, idx_data_points, shape):
143167
"""

pymc_experimental/bart/utils.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,9 @@ def plot_dependence(
309309
return axes
310310

311311

312-
def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, random_seed=None):
312+
def plot_variable_importance(
313+
idata, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
314+
):
313315
"""
314316
Estimates variable importance from the BART-posterior.
315317
@@ -319,9 +321,11 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
319321
InferenceData containing a collection of BART_trees in sample_stats group
320322
X : array-like
321323
The covariate matrix.
322-
labels: list
324+
labels : list
323325
List of the names of the covariates. If X is a DataFrame the names of the covariables will
324326
be taken from it and this argument will be ignored.
327+
sort_vars : bool
328+
Whether to sort the variables according to their variable importance. Defaults to True.
325329
figsize : tuple
326330
Figure size. If None it will be defined automatically.
327331
samples : int
@@ -337,23 +341,29 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
337341
_, axes = plt.subplots(2, 1, figsize=figsize)
338342

339343
if hasattr(X, "columns") and hasattr(X, "values"):
340-
labels = list(X.columns)
344+
labels = X.columns
341345
X = X.values
342346

343347
VI = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values
344348
if labels is None:
345-
labels = range(len(VI))
349+
labels = np.arange(len(VI))
350+
else:
351+
labels = np.array(labels)
346352

347353
ticks = np.arange(len(VI), dtype=int)
348354
idxs = np.argsort(VI)
349355
subsets = [idxs[:-i] for i in range(1, len(idxs))]
350356
subsets.append(None)
351357

352-
axes[0].plot(VI / VI.sum(), "o-")
358+
if sort_vars:
359+
indices = idxs[::-1]
360+
else:
361+
indices = np.arange(len(VI))
362+
axes[0].plot((VI / VI.sum())[indices], "o-")
353363
axes[0].set_xticks(ticks)
354-
axes[0].set_xticklabels(labels)
355-
axes[0].set_xlabel("variable index")
356-
axes[0].set_ylabel("relative importance")
364+
axes[0].set_xticklabels(labels[indices])
365+
axes[0].set_xlabel("covariables")
366+
axes[0].set_ylabel("importance")
357367

358368
predicted_all = predict(idata, rng, X=X, size=samples, excluded=None)
359369

@@ -363,16 +373,18 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
363373
predicted_subset = predict(idata, rng, X=X, size=samples, excluded=subset)
364374
pearson = np.zeros(samples)
365375
for j in range(samples):
366-
pearson[j] = pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
376+
pearson[j] = (
377+
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
378+
) ** 2
367379
EV_mean[idx] = np.mean(pearson)
368380
EV_hdi[idx] = az.hdi(pearson)
369381

370382
axes[1].errorbar(ticks, EV_mean, np.array((EV_mean - EV_hdi[:, 0], EV_hdi[:, 1] - EV_mean)))
371383

372384
axes[1].set_xticks(ticks)
373385
axes[1].set_xticklabels(ticks + 1)
374-
axes[1].set_xlabel("number of components")
375-
axes[1].set_ylabel("correlation")
386+
axes[1].set_xlabel("number of covariables")
387+
axes[1].set_ylabel("R²", rotation=0, labelpad=12)
376388
axes[1].set_ylim(0, 1)
377389

378390
axes[0].set_xlim(-0.5, len(VI) - 0.5)

0 commit comments

Comments
 (0)