@@ -40,8 +40,6 @@ class PGBART(ArrayStepShared):
         List of value variables for sampler
     num_particles : int
         Number of particles for the conditional SMC sampler. Defaults to 40
-    max_stages : int
-        Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
         during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -59,7 +57,6 @@ def __init__(
         self,
         vars=None,
         num_particles=40,
-        max_stages=100,
         batch="auto",
         model=None,
     ):
@@ -78,7 +75,6 @@ def __init__(
         self.missing_data = np.any(np.isnan(self.X))
         self.m = self.bart.m
         self.alpha = self.bart.alpha
-        self.k = self.bart.k
         self.alpha_vec = self.bart.split_prior
         if self.alpha_vec is None:
             self.alpha_vec = np.ones(self.X.shape[1])
@@ -87,10 +83,10 @@ def __init__(
         # if data is binary
         Y_unique = np.unique(self.Y)
         if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
-            self.mu_std = 6 / (self.k * self.m**0.5)
+            mu_std = 3 / self.m**0.5
         # maybe we need to check for count data
         else:
-            self.mu_std = (2 * self.Y.std()) / (self.k * self.m**0.5)
+            mu_std = self.Y.std() / self.m**0.5

         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
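Note: the new `mu_std` appears to fold the removed `self.k` into the constants (with the former default k = 2, 6/(k·√m) is exactly 3/√m and 2·std/(k·√m) is std/√m), so the leaf-value prior still shrinks as the number of trees `m` grows. A minimal standalone sketch of the two branches; `leaf_scale` is a hypothetical helper, not part of the module:

import numpy as np

def leaf_scale(y, m):
    # Mirrors the branch above: binary responses get a fixed numerator,
    # continuous responses use the empirical standard deviation of y.
    y_unique = np.unique(y)
    if y_unique.size == 2 and np.all(y_unique == [0, 1]):
        return 3 / m**0.5
    return y.std() / m**0.5

print(leaf_scale(np.array([0, 1, 1, 0]), m=50))      # ~0.42
print(leaf_scale(np.random.normal(size=500), m=50))  # ~0.14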
@@ -103,7 +99,8 @@ def __init__(
         )
         self.mean = fast_mean()

-        self.normal = NormalSampler()
+        self.normal = NormalSampler(mu_std)
+        self.uniform = UniformSampler(0.33, 0.75)
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
         self.ssv = SampleSplittingVariable(self.alpha_vec)

@@ -121,12 +118,11 @@ def __init__(
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(2, num_particles))
         self.len_indices = len(self.indices)
-        self.max_stages = max_stages

         shared = make_shared_replacements(initial_values, vars, model)
         self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
         self.all_particles = []
-        for i in range(self.m):
+        for _ in range(self.m):
             self.a_tree.leaf_node_value = self.init_mean / self.m
             p = ParticleTree(self.a_tree)
             self.all_particles.append(p)
@@ -138,25 +134,13 @@ def astep(self, _):

         tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
         for tree_id in tree_ids:
+            # Compute the sum of trees without the old tree that we are attempting to replace
+            self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
-            # Compute the sum of trees without the old tree, that we are attempting to replace
-            self.sum_trees_noi = self.sum_trees - particles[0].tree._predict()
-            # Resample leaf values for particle 1 which is a copy of the old tree
-            particles[1].sample_leafs(
-                self.sum_trees,
-                self.X,
-                self.mean,
-                self.m,
-                self.normal,
-                self.mu_std,
-            )

-            # The old tree and the one with new leafs do not grow so we update the weights only once
-            self.update_weight(particles[0], old=True)
-            self.update_weight(particles[1], old=True)
-            for _ in range(self.max_stages):
+            while True:
                 # Sample each particle (try to grow each tree), except for the first two
                 stop_growing = True
                 for p in particles[2:]:
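Note: the bounded `for _ in range(self.max_stages)` loop becomes an unbounded `while True` because growth terminates on its own: a particle stops producing `expansion_nodes` once every leaf has been either expanded or kept, so `stop_growing` eventually holds for all of them. (The `self.batch[~self.tune]` indexing above picks the first batch size while tuning and the second afterwards.) A toy sketch of the stopping rule, with a hypothetical `Particle` stub standing in for ParticleTree:

class Particle:
    # Hypothetical stub: each call to sample_tree expands one pending node.
    def __init__(self, expansion_nodes):
        self.expansion_nodes = expansion_nodes

    def sample_tree(self):
        if self.expansion_nodes:
            self.expansion_nodes.pop()
            return True
        return False

particles = [Particle([0, 1]), Particle([0])]
while True:
    stop_growing = True
    for p in particles:
        p.sample_tree()
        if p.expansion_nodes:
            stop_growing = False
    if stop_growing:
        break  # every tree is fully grown, no iteration cap needed
print("done growing")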
@@ -170,26 +154,26 @@ def astep(self, _):
                         self.mean,
                         self.m,
                         self.normal,
-                        self.mu_std,
                     )
                     if tree_grew:
                         self.update_weight(p)
                     if p.expansion_nodes:
                         stop_growing = False
                 if stop_growing:
                     break
+
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles[2:])
+                w_t, normalized_weights = self.normalize(particles[2:])

                 # Resample all but first two particles
                 new_indices = np.random.choice(
                     self.indices, size=self.len_indices, p=normalized_weights
                 )
                 particles[2:] = particles[new_indices]

-                # Set the new weights
+                # Set the new weight
                 for p in particles[2:]:
-                    p.log_weight = W_t
+                    p.log_weight = w_t

             for p in particles[2:]:
                 p.log_weight = p.old_likelihood_logp
@@ -216,27 +200,45 @@ def astep(self, _):
         return self.sum_trees, [stats]

     def normalize(self, particles):
-        """Use logsumexp trick to get W_t and softmax to get normalized_weights."""
+        """Use logsumexp trick to get w_t and softmax to get normalized_weights.
+
+        w_t is the un-normalized weight per particle, we will assign it to the
+        next round of particles, so they all start with the same weight.
+        """
         log_w = np.array([p.log_weight for p in particles])
         log_w_max = log_w.max()
         log_w_ = log_w - log_w_max
         w_ = np.exp(log_w_)
         w_sum = w_.sum()
-        W_t = log_w_max + np.log(w_sum) - self.log_num_particles
+        w_t = log_w_max + np.log(w_sum) - self.log_num_particles
         normalized_weights = w_ / w_sum
         # stabilize weights to avoid assigning exactly zero probability to a particle
         normalized_weights += 1e-12

-        return W_t, normalized_weights
+        return w_t, normalized_weights

     def init_particles(self, tree_id: int) -> np.ndarray:
         """Initialize particles."""
-        p = self.all_particles[tree_id]
-        particles = [p]
-        particles.append(copy(p))
+        p0 = self.all_particles[tree_id]
+        p1 = copy(p0)
+        p1.sample_leafs(
+            self.sum_trees,
+            self.mean,
+            self.m,
+            self.normal,
+        )
+        # The old tree and the one with new leafs do not grow so we update the weights only once
+        self.update_weight(p0, old=True)
+        self.update_weight(p1, old=True)

+        particles = [p0, p1]
         for _ in self.indices:
-            particles.append(ParticleTree(self.a_tree))
+            pt = ParticleTree(self.a_tree)
+            if self.tune:
+                pt.kf = self.uniform.random()
+            else:
+                pt.kf = p0.kf
+            particles.append(pt)

         return np.array(particles)

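Note: the `normalize` docstring names the logsumexp trick; subtracting the maximum before exponentiating avoids underflow when all log-weights are very negative, which direct `np.exp(log_w)` would not survive. A self-contained numeric sketch with toy log-weights:

import numpy as np

log_w = np.array([-1002.0, -1000.0, -999.5])  # np.exp(log_w) alone would underflow to 0
log_w_max = log_w.max()
w_ = np.exp(log_w - log_w_max)                # largest term becomes exp(0) = 1
w_sum = w_.sum()
w_t = log_w_max + np.log(w_sum) - np.log(log_w.size)  # log of the mean weight
normalized_weights = w_ / w_sum               # softmax of log_w
print(w_t, normalized_weights.sum())          # normalized weights sum to 1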
@@ -273,6 +275,7 @@ def __init__(self, tree):
         self.log_weight = 0
         self.old_likelihood_logp = 0
         self.used_variates = []
+        self.kf = 0.75

     def sample_tree(
         self,
@@ -285,7 +288,6 @@ def sample_tree(
         mean,
         m,
         normal,
-        mu_std,
     ):
         tree_grew = False
         if self.expansion_nodes:
@@ -305,7 +307,7 @@ def sample_tree(
                     mean,
                     m,
                     normal,
-                    mu_std,
+                    self.kf,
                 )
                 if index_selected_predictor is not None:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -315,9 +317,20 @@ def sample_tree(

         return tree_grew

-    def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
+    def sample_leafs(self, sum_trees, mean, m, normal):

-        sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
+        for idx in self.tree.idx_leaf_nodes:
+            if idx > 0:
+                leaf = self.tree[idx]
+                idx_data_points = leaf.idx_data_points
+                node_value = draw_leaf_value(
+                    sum_trees[idx_data_points],
+                    mean,
+                    m,
+                    normal,
+                    self.kf,
+                )
+                leaf.value = node_value


 class SampleSplittingVariable:
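Note: `self.kf` tempers the leaf proposals: it multiplies the particle's Gaussian perturbation, and per `init_particles` it is drawn from U(0.33, 0.75) while tuning and inherited from the kept tree afterwards. A hypothetical standalone version of the leaf draw, assuming the unshown tail of `draw_leaf_value` returns the noise plus the shrunken mean (which the visible lines suggest):

import numpy as np

def draw_leaf_value(y_mu_pred, m, scale, kf):
    # Shrunken mean of the current residuals plus kf-scaled Gaussian noise.
    if y_mu_pred.size == 0:
        return 0
    norm = np.random.normal(loc=0.0, scale=scale) * kf
    if y_mu_pred.size == 1:
        mu_mean = y_mu_pred.item() / m
    else:
        mu_mean = y_mu_pred.mean() / m
    return norm + mu_mean

print(draw_leaf_value(np.array([0.3, -0.1, 0.4]), m=50, scale=0.1, kf=0.5))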
@@ -375,7 +388,7 @@ def grow_tree(
     mean,
     m,
     normal,
-    mu_std,
+    kf,
 ):
     current_node = tree.get_node(index_leaf_node)
     idx_data_points = current_node.idx_data_points
@@ -406,11 +419,10 @@ def grow_tree(
         idx_data_point = new_idx_data_points[idx]
         node_value = draw_leaf_value(
             sum_trees[idx_data_point],
-            X[idx_data_point, selected_predictor],
             mean,
             m,
             normal,
-            mu_std,
+            kf,
         )

         new_node = LeafNode(
@@ -435,25 +447,6 @@ def grow_tree(
     return index_selected_predictor


-def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
-
-    for idx in tree.idx_leaf_nodes:
-        if idx > 0:
-            leaf = tree[idx]
-            idx_data_points = leaf.idx_data_points
-            parent_node = tree[leaf.get_idx_parent_node()]
-            selected_predictor = parent_node.idx_split_variable
-            node_value = draw_leaf_value(
-                sum_trees[idx_data_points],
-                X[idx_data_points, selected_predictor],
-                mean,
-                m,
-                normal,
-                mu_std,
-            )
-            leaf.value = node_value
-
-
 def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):

     left_idx = X[idx_data_points, selected_predictor] <= split_value
@@ -463,12 +456,12 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
     return left_node_idx_data_points, right_node_idx_data_points


-def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
+def draw_leaf_value(Y_mu_pred, mean, m, normal, kf):
     """Draw Gaussian distributed leaf values."""
     if Y_mu_pred.size == 0:
         return 0
     else:
-        norm = normal.random() * mu_std
+        norm = normal.random() * kf
         if Y_mu_pred.size == 1:
             mu_mean = Y_mu_pred.item() / m
         else:
@@ -507,17 +500,36 @@ def discrete_uniform_sampler(upper_value):
 class NormalSampler:
     """Cache samples from a standard normal distribution."""

-    def __init__(self):
+    def __init__(self, scale):
+        self.size = 1000
+        self.cache = []
+        self.scale = scale
+
+    def random(self):
+        if not self.cache:
+            self.update()
+        return self.cache.pop()
+
+    def update(self):
+        self.cache = np.random.normal(loc=0.0, scale=self.scale, size=self.size).tolist()
+
+
+class UniformSampler:
+    """Cache samples from a uniform distribution."""
+
+    def __init__(self, lower_bound, upper_bound):
         self.size = 1000
         self.cache = []
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound

     def random(self):
         if not self.cache:
             self.update()
         return self.cache.pop()

     def update(self):
-        self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist()
+        self.cache = np.random.uniform(self.lower_bound, self.upper_bound, size=self.size).tolist()


 def logp(point, out_vars, vars, shared):
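Note: both samplers amortize RNG calls by drawing 1000 values into a plain Python list and popping one at a time; `random()` refills the cache only when it runs dry. A usage sketch, assuming the two classes above are in scope and taking 0.1 as a stand-in for `mu_std`:

import numpy as np

normal = NormalSampler(0.1)             # N(0, 0.1) draws, cached 1000 at a time
uniform = UniformSampler(0.33, 0.75)    # tempering factors kf during tuning
samples = [normal.random() for _ in range(2500)]  # triggers three cache refills
kf = uniform.random()
print(len(samples), kf)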