 _log = logging.getLogger("pymc")
 
 
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree, log_weight, likelihood):
+        self.tree = tree.copy()  # keeps the tree we care about at the moment
+        self.expansion_nodes = [0]
+        self.log_weight = log_weight
+        self.old_likelihood_logp = likelihood
+        self.used_variates = []
+
+    def sample_tree_sequential(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees_output,
+        mean,
+        m,
+        normal,
+        mu_std,
+    ):
+        tree_grew = False
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                tree_grew, index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees_output,
+                    mean,
+                    m,
+                    normal,
+                    mu_std,
+                )
+                if tree_grew:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+
+        return tree_grew
+
+
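Aside: the `prob_leaf < np.random.random()` test above is a depth-dependent coin flip. A minimal sketch of that rule, assuming the standard BART tree prior for `prior_prob_leaf_node` (the actual table is computed elsewhere in this file; `alpha` and `beta` here are illustrative):

```python
import numpy as np

# Assumed: p(split at depth d) = alpha * (1 + d) ** (-beta), the usual BART
# tree prior, so the probability of remaining a leaf is its complement.
alpha, beta = 0.95, 2.0
depths = np.arange(10)
prior_prob_leaf_node = 1.0 - alpha * (1.0 + depths) ** (-beta)

depth = 2
if prior_prob_leaf_node[depth] < np.random.random():
    print("attempt to grow")    # the grow_tree() branch above
else:
    print("node stays a leaf")  # nothing is pushed onto expansion_nodes
```

Deeper nodes draw a larger `prob_leaf`, so growth becomes increasingly unlikely with depth, which is what keeps BART trees shallow.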
 class PGBART(ArrayStepShared):
     """
     Particle Gibbs BART sampling step
@@ -108,9 +161,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
 
         if self.chunk == "auto":
             self.chunk = max(1, int(self.m * 0.1))
-        self.num_particles = num_particles
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
+        self.len_indices = len(self.indices)
         self.max_stages = max_stages
 
         shared = make_shared_replacements(initial_values, vars, model)
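For intuition about the chunking above: roughly 10% of the `m` trees are regrown on each `astep` call, e.g.:

```python
m = 50                        # number of trees, illustrative value
chunk = max(1, int(m * 0.1))  # 5 trees are regrown per astep call
```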
@@ -137,24 +190,22 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         if self.idx == self.m:
             self.idx = 0
 
-        for idx in range(self.idx, self.idx + self.chunk):
-            if idx >= self.m:
+        for tree_id in range(self.idx, self.idx + self.chunk):
+            if tree_id >= self.m:
                 break
-            tree = self.all_particles[idx].tree
-            sum_trees_output_noi = sum_trees_output - tree.predict_output()
-            self.idx += 1
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id)
+            particles = self.init_particles(tree_id)
+            # Compute the sum of trees without the tree we are attempting to replace
+            self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
+            self.idx += 1
 
+            # The old tree is not growing, so we update its weight only once.
+            self.update_weight(particles[0])
             for t in range(self.max_stages):
-                # Get old particle at stage t
-                if t > 0:
-                    particles[0] = self.get_old_tree_particle(tree.tree_id, t)
-                # sample each particle (try to grow each tree)
-                compute_logp = [True]
+                # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
-                    clp = p.sample_tree_sequential(
+                    tree_grew = p.sample_tree_sequential(
                         self.ssv,
                         self.available_predictors,
                         self.prior_prob_leaf_node,
@@ -166,22 +217,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                     )
-                    compute_logp.append(clp)
-                # Update weights. Since the prior is used as the proposal, the weights
-                # are updated additively as the ratio of the new and old log-likelihoods
-                for clp, p in zip(compute_logp, particles):
-                    if clp:  # Compute the likelihood when p has changed from the previous iteration
-                        new_likelihood = self.likelihood_logp(
-                            sum_trees_output_noi + p.tree.predict_output()
-                        )
-                        p.log_weight += new_likelihood - p.old_likelihood_logp
-                        p.old_likelihood_logp = new_likelihood
+                    if tree_grew:
+                        self.update_weight(p)
                 # Normalize weights
                 W_t, normalized_weights = self.normalize(particles)
 
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
-                new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
+                new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]
 
                 # Set the new weights
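The resampling step above keeps particle 0 (the old tree) fixed and redraws the remaining particles in proportion to their weights. A self-contained sketch of that pattern (illustrative log-weights; the real code softmaxes all particles via `normalize` and then renormalizes the tail, which yields the same `re_n_w`):

```python
import numpy as np

rng = np.random.default_rng(0)

# Five particles; index 0 is the frozen old tree and is never resampled.
log_weights = np.array([-3.2, -2.9, -4.1, -3.0, -3.5])

# Softmax over particles 1..n-1 only.
w = np.exp(log_weights[1:] - log_weights[1:].max())
re_n_w = w / w.sum()

# Multinomial resampling, mirroring np.random.choice(self.indices, ...) above.
indices = np.arange(1, len(log_weights))
new_indices = rng.choice(indices, size=len(indices), p=re_n_w)
# particles[1:] = particles[new_indices] would then duplicate strong particles
# and drop weak ones.
print(new_indices)
```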
@@ -200,8 +243,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
-            self.all_particles[tree.tree_id] = new_particle
-            sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
+            self.all_particles[tree_id] = new_particle
+            sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
 
             if self.tune:
                 for index in new_particle.used_variates:
@@ -232,7 +275,7 @@ def competence(var, has_grad):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE
 
-    def normalize(self, particles):
+    def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
         """
         Use the logsumexp trick to get W_t and softmax to get normalized_weights
         """
@@ -248,16 +291,11 @@ def normalize(self, particles):
 
         return W_t, normalized_weights
 
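The body of `normalize` is elided by the diff; for reference, a sketch of the logsumexp/softmax computation its docstring describes (an assumed reconstruction, not the method itself):

```python
from typing import Tuple

import numpy as np

def normalize_sketch(log_w: np.ndarray, log_num_particles: float) -> Tuple[float, np.ndarray]:
    # Shift by the max so the exponentials cannot overflow.
    log_w_max = log_w.max()
    w = np.exp(log_w - log_w_max)
    w_sum = w.sum()
    # W_t: log of the average unnormalized weight (logsumexp minus log n).
    W_t = log_w_max + np.log(w_sum) - log_num_particles
    # Softmax of the log-weights.
    normalized_weights = w / w_sum
    return W_t, normalized_weights
```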
-    def get_old_tree_particle(self, tree_id, t):
-        old_tree_particle = self.all_particles[tree_id]
-        old_tree_particle.set_particle_to_step(t)
-        return old_tree_particle
-
-    def init_particles(self, tree_id):
+    def init_particles(self, tree_id: int) -> np.ndarray:
         """
         Initialize particles
         """
-        p = self.get_old_tree_particle(tree_id, 0)
+        p = self.all_particles[tree_id]
         p.log_weight = self.init_log_weight
         p.old_likelihood_logp = self.init_likelihood
         particles = [p]
@@ -274,68 +312,18 @@ def init_particles(self, tree_id):
 
         return np.array(particles)
 
+    def update_weight(self, particle: ParticleTree) -> None:
+        """
+        Update the weight of a particle
 
-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree, log_weight, likelihood):
-        self.tree = tree.copy()  # keeps the tree that we care at the moment
-        self.expansion_nodes = [0]
-        self.tree_history = [self.tree]
-        self.expansion_nodes_history = [self.expansion_nodes]
-        self.log_weight = log_weight
-        self.old_likelihood_logp = likelihood
-        self.used_variates = []
-
-    def sample_tree_sequential(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees_output,
-        mean,
-        m,
-        normal,
-        mu_std,
-    ):
-        clp = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                clp, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees_output,
-                    mean,
-                    m,
-                    normal,
-                    mu_std,
-                )
-                if clp:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-
-        self.tree_history.append(self.tree)
-        self.expansion_nodes_history.append(self.expansion_nodes)
-        return clp
-
-    def set_particle_to_step(self, t):
-        if len(self.tree_history) <= t:
-            t = -1
-        self.tree = self.tree_history[t]
-        self.expansion_nodes = self.expansion_nodes_history[t]
+        Since the prior is used as the proposal, the weights are updated additively as the ratio
+        of the new and old log-likelihoods.
+        """
+        new_likelihood = self.likelihood_logp(
+            self.sum_trees_output_noi + particle.tree.predict_output()
+        )
+        particle.log_weight += new_likelihood - particle.old_likelihood_logp
+        particle.old_likelihood_logp = new_likelihood
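To make the additive update concrete, a toy worked example (a stand-in Gaussian `likelihood_logp`; none of these numbers come from PyMC):

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0])

def likelihood_logp(mu):
    # Illustrative Gaussian log-likelihood of the residuals.
    return float(-0.5 * np.sum((y - mu) ** 2))

sum_trees_output_noi = np.array([0.9, 1.8, 2.7])  # fit of all the other trees
old_pred = np.zeros(3)                            # tree output before growing
new_pred = np.array([0.1, 0.2, 0.2])              # tree output after growing

log_weight = 0.0
old_likelihood_logp = likelihood_logp(sum_trees_output_noi + old_pred)

# Exactly the update_weight logic: add the log-likelihood ratio.
new_likelihood = likelihood_logp(sum_trees_output_noi + new_pred)
log_weight += new_likelihood - old_likelihood_logp
old_likelihood_logp = new_likelihood
print(log_weight)  # > 0 here, since the grown tree fits y better
```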
 
 
 
 def preprocess_XY(X, Y):