@@ -28,58 +28,6 @@
 _log = logging.getLogger("pymc")
 
 
-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree):
-        self.tree = tree.copy()  # keeps the tree we currently care about
-        self.expansion_nodes = [0]
-        self.used_variates = []
-
-    def sample_tree(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees,
-        mean,
-        linear_fit,
-        m,
-        normal,
-        mu_std,
-        response,
-    ):
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                tree_grew, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees,
-                    mean,
-                    linear_fit,
-                    m,
-                    normal,
-                    mu_std,
-                    response,
-                )
-                if tree_grew:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-
-
 class PGBART(ArrayStepShared):
     """
     Particle Gibbs BART sampling step
@@ -138,7 +86,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
 
         self.sum_trees = np.full_like(self.Y, self.init_mean * self.m).astype(aesara.config.floatX)
         self.a_tree = Tree.init_tree(
-            tree_id=0,
             leaf_node_value=self.init_mean,
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             m=self.m,
@@ -162,14 +109,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
 
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
-        self.log_likelihoods = np.empty(num_particles)
+        self.log_weights = np.empty(num_particles)
         self.max_stages = max_stages
 
         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
         self.all_particles = []
-        for i in range(self.m):
-            self.a_tree.tree_id = i
+        for _ in range(self.m):
             p = ParticleTree(self.a_tree)
             self.all_particles.append(p)
         self.all_trees = np.array([p.tree for p in self.all_particles])
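Note on the bookkeeping in this hunk: with tree_id removed, a tree is identified purely by its position in all_particles / all_trees, and log_weights is a persistent per-particle buffer replacing the per-sweep log_likelihoods. A minimal sketch of that pattern, with stub stand-ins for the Tree and ParticleTree classes defined in this file (all names and sizes here are illustrative only, not the module's API):

import numpy as np

class Tree:                       # stub for the module's Tree class
    def copy(self):
        return Tree()

class ParticleTree:               # mirrors the class moved further down
    def __init__(self, tree):
        self.tree = tree.copy()

num_particles, m = 10, 50                  # illustrative sizes
log_weights = np.empty(num_particles)      # slot i <-> particle i, reused each sweep
a_tree = Tree()                            # template tree, as in __init__ above
all_particles = [ParticleTree(a_tree) for _ in range(m)]  # identity = list index
all_trees = np.array([p.tree for p in all_particles])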
@@ -184,7 +130,6 @@ def astep(self, _):
         if self.iter >= self.m * 3:
             for tree_id in tree_ids:
                 self.sum_trees = self.sum_trees - self.all_particles[tree_id].tree.predict_output()
-                self.a_tree.tree_id = tree_id
                 p = ParticleTree(self.a_tree)
                 self.all_particles[tree_id] = p
                 self.all_trees[tree_id] = p.tree
@@ -196,13 +141,15 @@ def astep(self, _):
             # Generate an initial set of particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
+            # update the weight of the previous tree
+            self.log_weights[0] = self.likelihood_logp(self.sum_trees)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
 
             for _ in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 stop_growing = 1
-                for p in particles[1:]:
+                for idx, p in enumerate(particles[1:]):
                     p.sample_tree(
                         self.ssv,
                         self.available_predictors,
@@ -219,22 +166,20 @@ def astep(self, _):
                     )
                     if p.expansion_nodes:
                         stop_growing = 0
+                    else:
+                        self.log_weights[idx + 1] = self.likelihood_logp(
+                            self.sum_trees_noi + p.tree.predict_output()
+                        )
                 if stop_growing:
                     break
 
-            for idx, p in enumerate(particles):
-                self.log_likelihoods[idx] = self.likelihood_logp(
-                    self.sum_trees_noi + p.tree.predict_output()
-                )
-
-            normalized_weights = normalize(self.log_likelihoods)
+            normalized_weights = normalize_weights(self.log_weights)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
             self.all_trees[tree_id] = new_tree
             self.all_particles[tree_id] = new_particle
-            new_pred = new_tree.predict_output()
-            self.sum_trees = self.sum_trees_noi + new_pred
+            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
 
             if self.tune:
                 self.ssv = SampleSplittingVariable(self.alpha_vec)
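The net effect of this hunk is that log_weights is filled incrementally: slot 0 is scored once from the kept previous tree (see the earlier hunk), and each growing particle's slot idx + 1 is scored as soon as it has no expansion nodes left, so the separate scoring loop over all particles disappears. A self-contained sketch of the subsequent select-and-update step under those assumptions (the values are made up; the helper matches the renamed normalize_weights below, and indices stand in for particle objects):

import numpy as np

def normalize_weights(log_weights):
    w = np.exp(log_weights - log_weights.max())   # softmax, shifted for stability
    return w / w.sum()

rng = np.random.default_rng(0)
log_weights = np.array([-10.2, -9.8, -11.5, -9.9])   # slot 0 = previous tree
probs = normalize_weights(log_weights)
keep = rng.choice(len(log_weights), p=probs)         # surviving particle's index
print(keep, probs.round(3))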
@@ -270,12 +215,12 @@ def init_particles(self, tree_id: int) -> np.ndarray:
         return np.array(particles)
 
 
-def normalize(log_likelihoods):
+def normalize_weights(log_weights):
     """
     Use softmax to get normalized_weights
     """
-    log_w_max = log_likelihoods.max()
-    log_w_ = log_likelihoods - log_w_max
+    log_w_max = log_weights.max()
+    log_w_ = log_weights - log_w_max
     w_ = np.exp(log_w_)
     normalized_weights = w_ / w_.sum()
     # stabilize weights to avoid assigning exactly zero probability to a particle
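Why the maximum is subtracted before exponentiating: softmax is invariant to adding a constant to every log weight, and shifting by the max guarantees the largest term is exp(0) = 1, so the sum cannot underflow to zero. A quick demonstration:

import numpy as np

log_w = np.array([-1000.0, -1001.0, -1002.0])
naive = np.exp(log_w)                    # underflows: [0., 0., 0.]
shifted = np.exp(log_w - log_w.max())    # [1., 0.368, 0.135]
print(naive.sum())                       # 0.0 -> division would fail
print(shifted / shifted.sum())           # [0.665, 0.245, 0.090]

(The stabilization step referenced in the trailing comment falls outside this hunk and is not reproduced here.)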
@@ -284,6 +229,58 @@ def normalize(log_likelihoods):
     return normalized_weights
 
 
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree):
+        self.tree = tree.copy()  # keeps the tree we currently care about
+        self.expansion_nodes = [0]
+        self.used_variates = []
+
+    def sample_tree(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees,
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
+    ):
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                tree_grew, index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees,
+                    mean,
+                    linear_fit,
+                    m,
+                    normal,
+                    mu_std,
+                    response,
+                )
+                if tree_grew:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+
+
 class SampleSplittingVariable:
     def __init__(self, alpha_vec):
         """
@@ -401,13 +398,6 @@ def grow_tree(
     tree.set_node(index_leaf_node, new_split_node)
     tree.set_node(new_nodes[0].index, new_nodes[0])
     tree.set_node(new_nodes[1].index, new_nodes[1])
-    # The new SplitNode is a prunable node since it has both children.
-    tree.idx_prunable_split_nodes.append(index_leaf_node)
-    # If the parent of the node from which the tree is growing was a prunable node,
-    # remove from the list since one of its children is a SplitNode now
-    parent_index = current_node.get_idx_parent_node()
-    if parent_index in tree.idx_prunable_split_nodes:
-        tree.idx_prunable_split_nodes.remove(parent_index)
 
     return True, index_selected_predictor
 