Skip to content

Commit 8c59b41

Browse files
authored
BART: clamp first particle to old full tree (#5011)
1 parent 1048b69 commit 8c59b41

File tree

1 file changed

+83
-95
lines changed

1 file changed

+83
-95
lines changed

pymc/step_methods/pgbart.py

Lines changed: 83 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,59 @@
3232
_log = logging.getLogger("pymc")
3333

3434

35+
class ParticleTree:
    """A single SMC particle wrapping a regression tree under construction."""

    def __init__(self, tree, log_weight, likelihood):
        # Keep a private copy so growing this particle never mutates the
        # tree instance owned by the sampler.
        self.tree = tree.copy()
        # Leaf indexes still pending a grow attempt; the root starts queued.
        self.expansion_nodes = [0]
        self.log_weight = log_weight
        self.old_likelihood_logp = likelihood
        # Predictor indexes used by successful splits (for tuning stats).
        self.used_variates = []

    def sample_tree_sequential(
        self,
        ssv,
        available_predictors,
        prior_prob_leaf_node,
        X,
        missing_data,
        sum_trees_output,
        mean,
        m,
        normal,
        mu_std,
    ):
        """Attempt one grow move at the next pending leaf.

        Returns a truthy value when the tree grew, False otherwise.
        """
        if not self.expansion_nodes:
            return False

        leaf_index = self.expansion_nodes.pop(0)
        # Probability that this node remains a leaf, as a function of depth.
        stay_leaf_prob = prior_prob_leaf_node[self.tree[leaf_index].depth]
        if stay_leaf_prob >= np.random.random():
            return False

        grew, selected_predictor = grow_tree(
            self.tree,
            leaf_index,
            ssv,
            available_predictors,
            X,
            missing_data,
            sum_trees_output,
            mean,
            m,
            normal,
            mu_std,
        )
        if grew:
            # A successful split produces two fresh leaves; queue them for
            # future expansion and record the predictor that was used.
            self.expansion_nodes.extend(self.tree.idx_leaf_nodes[-2:])
            self.used_variates.append(selected_predictor)

        return grew
86+
87+
3588
class PGBART(ArrayStepShared):
3689
"""
3790
Particle Gibbs BART sampling step
@@ -108,9 +161,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
108161

109162
if self.chunk == "auto":
110163
self.chunk = max(1, int(self.m * 0.1))
111-
self.num_particles = num_particles
112164
self.log_num_particles = np.log(num_particles)
113165
self.indices = list(range(1, num_particles))
166+
self.len_indices = len(self.indices)
114167
self.max_stages = max_stages
115168

116169
shared = make_shared_replacements(initial_values, vars, model)
@@ -137,24 +190,22 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
137190
if self.idx == self.m:
138191
self.idx = 0
139192

140-
for idx in range(self.idx, self.idx + self.chunk):
141-
if idx >= self.m:
193+
for tree_id in range(self.idx, self.idx + self.chunk):
194+
if tree_id >= self.m:
142195
break
143-
tree = self.all_particles[idx].tree
144-
sum_trees_output_noi = sum_trees_output - tree.predict_output()
145-
self.idx += 1
146196
# Generate an initial set of SMC particles
147197
# at the end of the algorithm we return one of these particles as the new tree
148-
particles = self.init_particles(tree.tree_id)
198+
particles = self.init_particles(tree_id)
199+
# Compute the sum of trees without the tree we are attempting to replace
200+
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
201+
self.idx += 1
149202

203+
# The old tree is not growing so we update the weights only once.
204+
self.update_weight(particles[0])
150205
for t in range(self.max_stages):
151-
# Get old particle at stage t
152-
if t > 0:
153-
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
154-
# sample each particle (try to grow each tree)
155-
compute_logp = [True]
206+
# Sample each particle (try to grow each tree), except for the first one.
156207
for p in particles[1:]:
157-
clp = p.sample_tree_sequential(
208+
tree_grew = p.sample_tree_sequential(
158209
self.ssv,
159210
self.available_predictors,
160211
self.prior_prob_leaf_node,
@@ -166,22 +217,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
166217
self.normal,
167218
self.mu_std,
168219
)
169-
compute_logp.append(clp)
170-
# Update weights. Since the prior is used as the proposal,the weights
171-
# are updated additively as the ratio of the new and old log_likelihoods
172-
for clp, p in zip(compute_logp, particles):
173-
if clp: # Compute the likelihood when p has changed from the previous iteration
174-
new_likelihood = self.likelihood_logp(
175-
sum_trees_output_noi + p.tree.predict_output()
176-
)
177-
p.log_weight += new_likelihood - p.old_likelihood_logp
178-
p.old_likelihood_logp = new_likelihood
220+
if tree_grew:
221+
self.update_weight(p)
179222
# Normalize weights
180223
W_t, normalized_weights = self.normalize(particles)
181224

182225
# Resample all but first particle
183226
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
184-
new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
227+
new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
185228
particles[1:] = particles[new_indices]
186229

187230
# Set the new weights
@@ -200,8 +243,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
200243
new_particle = np.random.choice(particles, p=normalized_weights)
201244
new_tree = new_particle.tree
202245
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
203-
self.all_particles[tree.tree_id] = new_particle
204-
sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
246+
self.all_particles[tree_id] = new_particle
247+
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
205248

206249
if self.tune:
207250
for index in new_particle.used_variates:
@@ -232,7 +275,7 @@ def competence(var, has_grad):
232275
return Competence.IDEAL
233276
return Competence.INCOMPATIBLE
234277

235-
def normalize(self, particles):
278+
def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
236279
"""
237280
Use logsumexp trick to get W_t and softmax to get normalized_weights
238281
"""
@@ -248,16 +291,11 @@ def normalize(self, particles):
248291

249292
return W_t, normalized_weights
250293

251-
def get_old_tree_particle(self, tree_id, t):
252-
old_tree_particle = self.all_particles[tree_id]
253-
old_tree_particle.set_particle_to_step(t)
254-
return old_tree_particle
255-
256-
def init_particles(self, tree_id):
294+
def init_particles(self, tree_id: int) -> np.ndarray:
257295
"""
258296
Initialize particles
259297
"""
260-
p = self.get_old_tree_particle(tree_id, 0)
298+
p = self.all_particles[tree_id]
261299
p.log_weight = self.init_log_weight
262300
p.old_likelihood_logp = self.init_likelihood
263301
particles = [p]
@@ -274,68 +312,18 @@ def init_particles(self, tree_id):
274312

275313
return np.array(particles)
276314

315+
def update_weight(self, particle: ParticleTree) -> None:
316+
"""
317+
Update the weight of a particle
277318
278-
class ParticleTree:
279-
"""
280-
Particle tree
281-
"""
282-
283-
def __init__(self, tree, log_weight, likelihood):
284-
self.tree = tree.copy() # keeps the tree that we care at the moment
285-
self.expansion_nodes = [0]
286-
self.tree_history = [self.tree]
287-
self.expansion_nodes_history = [self.expansion_nodes]
288-
self.log_weight = log_weight
289-
self.old_likelihood_logp = likelihood
290-
self.used_variates = []
291-
292-
def sample_tree_sequential(
293-
self,
294-
ssv,
295-
available_predictors,
296-
prior_prob_leaf_node,
297-
X,
298-
missing_data,
299-
sum_trees_output,
300-
mean,
301-
m,
302-
normal,
303-
mu_std,
304-
):
305-
clp = False
306-
if self.expansion_nodes:
307-
index_leaf_node = self.expansion_nodes.pop(0)
308-
# Probability that this node will remain a leaf node
309-
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
310-
311-
if prob_leaf < np.random.random():
312-
clp, index_selected_predictor = grow_tree(
313-
self.tree,
314-
index_leaf_node,
315-
ssv,
316-
available_predictors,
317-
X,
318-
missing_data,
319-
sum_trees_output,
320-
mean,
321-
m,
322-
normal,
323-
mu_std,
324-
)
325-
if clp:
326-
new_indexes = self.tree.idx_leaf_nodes[-2:]
327-
self.expansion_nodes.extend(new_indexes)
328-
self.used_variates.append(index_selected_predictor)
329-
330-
self.tree_history.append(self.tree)
331-
self.expansion_nodes_history.append(self.expansion_nodes)
332-
return clp
333-
334-
def set_particle_to_step(self, t):
335-
if len(self.tree_history) <= t:
336-
t = -1
337-
self.tree = self.tree_history[t]
338-
self.expansion_nodes = self.expansion_nodes_history[t]
319+
Since the prior is used as the proposal, the weights are updated additively as the ratio of
320+
the new and old log-likelihoods.
321+
"""
322+
new_likelihood = self.likelihood_logp(
323+
self.sum_trees_output_noi + particle.tree.predict_output()
324+
)
325+
particle.log_weight += new_likelihood - particle.old_likelihood_logp
326+
particle.old_likelihood_logp = new_likelihood
339327

340328

341329
def preprocess_XY(X, Y):

0 commit comments

Comments
 (0)