Commit fe4b769

tidy up code
1 parent 0a96f23 commit fe4b769

2 files changed: +69, -88 lines


pymc/bart/pgbart.py

Lines changed: 66 additions & 76 deletions
@@ -28,58 +28,6 @@
 _log = logging.getLogger("pymc")
 
 
-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree):
-        self.tree = tree.copy()  # keeps the tree that we care at the moment
-        self.expansion_nodes = [0]
-        self.used_variates = []
-
-    def sample_tree(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees,
-        mean,
-        linear_fit,
-        m,
-        normal,
-        mu_std,
-        response,
-    ):
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                tree_grew, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees,
-                    mean,
-                    linear_fit,
-                    m,
-                    normal,
-                    mu_std,
-                    response,
-                )
-                if tree_grew:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-
-
 class PGBART(ArrayStepShared):
     """
     Particle Gibss BART sampling step
@@ -138,7 +86,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
 
         self.sum_trees = np.full_like(self.Y, self.init_mean * self.m).astype(aesara.config.floatX)
         self.a_tree = Tree.init_tree(
-            tree_id=0,
             leaf_node_value=self.init_mean,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             m=self.m,
@@ -162,14 +109,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
 
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
-        self.log_likelihoods = np.empty(num_particles)
+        self.log_weights = np.empty(num_particles)
         self.max_stages = max_stages
 
         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
         self.all_particles = []
-        for i in range(self.m):
-            self.a_tree.tree_id = i
+        for _ in range(self.m):
             p = ParticleTree(self.a_tree)
             self.all_particles.append(p)
         self.all_trees = np.array([p.tree for p in self.all_particles])
@@ -184,7 +130,6 @@ def astep(self, _):
         if self.iter >= self.m * 3:
             for tree_id in tree_ids:
                 self.sum_trees = self.sum_trees - self.all_particles[tree_id].tree.predict_output()
-                self.a_tree.tree_id = tree_id
                 p = ParticleTree(self.a_tree)
                 self.all_particles[tree_id] = p
                 self.all_trees[tree_id] = p.tree
@@ -196,13 +141,15 @@ def astep(self, _):
             # Generate an initial set of particles
            # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
+            # update weight previous tree
+            self.log_weights[0] = self.likelihood_logp(self.sum_trees)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
 
             for _ in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 stop_growing = 1
-                for p in particles[1:]:
+                for idx, p in enumerate(particles[1:]):
                     p.sample_tree(
                         self.ssv,
                         self.available_predictors,
@@ -219,22 +166,20 @@ def astep(self, _):
                     )
                     if p.expansion_nodes:
                         stop_growing = 0
+                    else:
+                        self.log_weights[idx + 1] = self.likelihood_logp(
+                            self.sum_trees_noi + p.tree.predict_output()
+                        )
                 if stop_growing:
                     break
 
-            for idx, p in enumerate(particles):
-                self.log_likelihoods[idx] = self.likelihood_logp(
-                    self.sum_trees_noi + p.tree.predict_output()
-                )
-
-            normalized_weights = normalize(self.log_likelihoods)
+            normalized_weights = normalize_weights(self.log_weights)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
             self.all_trees[tree_id] = new_tree
             self.all_particles[tree_id] = new_particle
-            new_pred = new_tree.predict_output()
-            self.sum_trees = self.sum_trees_noi + new_pred
+            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
 
             if self.tune:
                 self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -270,12 +215,12 @@ def init_particles(self, tree_id: int) -> np.ndarray:
         return np.array(particles)
 
 
-def normalize(log_likelihoods):
+def normalize_weights(log_weights):
     """
     Use softmax to get normalized_weights
     """
-    log_w_max = log_likelihoods.max()
-    log_w_ = log_likelihoods - log_w_max
+    log_w_max = log_weights.max()
+    log_w_ = log_weights - log_w_max
     w_ = np.exp(log_w_)
     normalized_weights = w_ / w_.sum()
     # stabilize weights to avoid assigning exactly zero probability to a particle
@@ -284,6 +229,58 @@ def normalize(log_likelihoods):
     return normalized_weights
 
 
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree):
+        self.tree = tree.copy()  # keeps the tree that we care at the moment
+        self.expansion_nodes = [0]
+        self.used_variates = []
+
+    def sample_tree(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees,
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
+    ):
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                tree_grew, index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees,
+                    mean,
+                    linear_fit,
+                    m,
+                    normal,
+                    mu_std,
+                    response,
+                )
+                if tree_grew:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+
+
 class SampleSplittingVariable:
     def __init__(self, alpha_vec):
         """
@@ -401,13 +398,6 @@ def grow_tree(
     tree.set_node(index_leaf_node, new_split_node)
     tree.set_node(new_nodes[0].index, new_nodes[0])
     tree.set_node(new_nodes[1].index, new_nodes[1])
-    # The new SplitNode is a prunable node since it has both children.
-    tree.idx_prunable_split_nodes.append(index_leaf_node)
-    # If the parent of the node from which the tree is growing was a prunable node,
-    # remove from the list since one of its children is a SplitNode now
-    parent_index = current_node.get_idx_parent_node()
-    if parent_index in tree.idx_prunable_split_nodes:
-        tree.idx_prunable_split_nodes.remove(parent_index)
 
     return True, index_selected_predictor
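For context, the renamed normalize_weights helper is a standard max-shifted softmax. Below is a minimal standalone sketch of that idea, assuming only NumPy and omitting the extra stabilization step referenced by the in-code comment:

    import numpy as np

    def normalize_weights(log_weights):
        # Subtract the largest log weight before exponentiating; the shared factor
        # cancels when normalizing, so the result is unchanged, but np.exp can no
        # longer overflow and the largest weight cannot underflow to zero.
        log_w = log_weights - log_weights.max()
        w = np.exp(log_w)
        return w / w.sum()

    # Log weights this small would all underflow to 0.0 with a naive np.exp:
    print(normalize_weights(np.array([-1000.0, -1001.0, -1002.0])))
    # -> approximately [0.665, 0.245, 0.090]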

pymc/bart/tree.py

Lines changed: 3 additions & 12 deletions
@@ -38,28 +38,20 @@ class Tree:
         Total number of nodes.
     idx_leaf_nodes : list
         List with the index of the leaf nodes of the tree.
-    idx_prunable_split_nodes : list
-        List with the index of the prunable splitting nodes of the tree. A splitting node is
-        prunable if both its children are leaf nodes.
-    tree_id : int
-        Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
     num_observations : int
         Number of observations used to fit BART.
     m : int
         Number of trees
 
     Parameters
     ----------
-    tree_id : int, optional
     num_observations : int, optional
     """
 
-    def __init__(self, tree_id=0, num_observations=0, m=0):
+    def __init__(self, num_observations=0, m=0):
         self.tree_structure = {}
         self.num_nodes = 0
         self.idx_leaf_nodes = []
-        self.idx_prunable_split_nodes = []
-        self.tree_id = tree_id
         self.num_observations = num_observations
         self.m = m
 
@@ -144,12 +136,11 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
         return current_node, split_variable
 
     @staticmethod
-    def init_tree(tree_id, leaf_node_value, idx_data_points, m):
+    def init_tree(leaf_node_value, idx_data_points, m):
         """
 
         Parameters
         ----------
-        tree_id
         leaf_node_value
         idx_data_points
         m : int
@@ -159,7 +150,7 @@ def init_tree(tree_id, leaf_node_value, idx_data_points, m):
         -------
 
         """
-        new_tree = Tree(tree_id, len(idx_data_points), m)
+        new_tree = Tree(len(idx_data_points), m)
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
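With tree_id gone, Tree.init_tree takes only the leaf value, the data-point indices, and the number of trees, matching the call in PGBART.__init__. A hypothetical usage sketch of the new signature; the import path and all values below are illustrative, not taken from the commit:

    import numpy as np
    from pymc.bart.tree import Tree  # assumed import path for pymc/bart/tree.py

    # A single-leaf tree whose root holds every data point, mirroring the
    # PGBART.__init__ call with init_mean and num_observations.
    a_tree = Tree.init_tree(
        leaf_node_value=0.5,                            # illustrative initial leaf value
        idx_data_points=np.arange(100, dtype="int32"),  # illustrative 100 observations
        m=50,                                           # illustrative ensemble size
    )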
