class PGBART(ArrayStepShared):
    """
-    Particle Gibss BART sampling step
+    Particle Gibbs BART sampling step.

    Parameters
    ----------
@@ -208,9 +208,7 @@ def astep(self, _):
        return self.sum_trees, [stats]

    def normalize(self, particles):
-        """
-        Use logsumexp trick to get W_t and softmax to get normalized_weights
-        """
+        """Use logsumexp trick to get W_t and softmax to get normalized_weights."""
        log_w = np.array([p.log_weight for p in particles])
        log_w_max = log_w.max()
        log_w_ = log_w - log_w_max
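For readers unfamiliar with the trick named in the docstring above, here is a minimal, self-contained sketch of max-shifted logsumexp plus softmax normalization; the function name `normalize_weights` and its return convention are illustrative, not taken from this commit.

```python
import numpy as np

def normalize_weights(log_weights):
    """Return (logsumexp of the log weights, softmax-normalized weights)."""
    log_w = np.asarray(log_weights, dtype=float)
    log_w_max = log_w.max()                 # shift by the max so exp() cannot overflow
    w = np.exp(log_w - log_w_max)           # every value is now in (0, 1]
    w_sum = w.sum()
    log_total = np.log(w_sum) + log_w_max   # logsumexp of the original log weights
    return log_total, w / w_sum             # normalized weights sum to 1

log_total, probs = normalize_weights([-1000.0, -1001.0, -999.5])
assert np.isclose(probs.sum(), 1.0)
```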
@@ -224,9 +222,7 @@ def normalize(self, particles):
        return W_t, normalized_weights

    def init_particles(self, tree_id: int) -> np.ndarray:
-        """
-        Initialize particles
-        """
+        """Initialize particles."""
        p = self.all_particles[tree_id]
        particles = [p]
        particles.append(copy(p))
@@ -238,7 +234,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:

    def update_weight(self, particle, old=False):
        """
-        Update the weight of a particle
+        Update the weight of a particle.

        Since the prior is used as the proposal, the weights are updated additively as the ratio of
        the new and old log-likelihoods.
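As an illustration of the update rule described above (prior as proposal, so the weight update is the likelihood ratio, additive in log space), here is a toy sketch; `ToyParticle` and its attributes are assumptions made for the example, not the particle class used in this file.

```python
from dataclasses import dataclass

@dataclass
class ToyParticle:
    log_weight: float
    old_likelihood_logp: float

def update_weight(particle: ToyParticle, new_likelihood_logp: float) -> None:
    # With the prior as the proposal, the incremental importance weight is the
    # likelihood ratio, which is an additive update on the log scale.
    particle.log_weight += new_likelihood_logp - particle.old_likelihood_logp
    particle.old_likelihood_logp = new_likelihood_logp

p = ToyParticle(log_weight=0.0, old_likelihood_logp=-10.0)
update_weight(p, new_likelihood_logp=-8.5)
assert p.log_weight == 1.5
```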
@@ -253,19 +249,15 @@ def update_weight(self, particle, old=False):

    @staticmethod
    def competence(var, has_grad):
-        """
-        PGBART is only suitable for BART distributions
-        """
+        """PGBART is only suitable for BART distributions."""
        dist = getattr(var.owner, "op", None)
        if isinstance(dist, BARTRV):
            return Competence.IDEAL
        return Competence.INCOMPATIBLE


class ParticleTree:
-    """
-    Particle tree
-    """
+    """Particle tree."""

    def __init__(self, tree):
        self.tree = tree.copy()  # keeps the tree we care about at the moment
@@ -340,6 +332,7 @@ def rvs(self):
def compute_prior_probability(alpha):
    """
    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
+
    Taken from equation 19 in [Rockova2018].

    Parameters
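Purely as an aside, here is a sketch of a depth-indexed table of leaf probabilities under the assumption that p(split at depth d) = alpha**d, which is how equation 19 of Rockova (2018) is commonly read; the helper name and the fixed `max_depth` cut-off are illustrative, not the actual implementation in this file.

```python
def leaf_prior_probabilities(alpha, max_depth=10):
    # Probability that a node at depth d is a leaf, assuming the split prior
    # p(split at depth d) = alpha ** d, hence p(leaf at depth d) = 1 - alpha ** d.
    return [1 - alpha ** depth for depth in range(max_depth + 1)]

probs = leaf_prior_probabilities(0.25, max_depth=4)
# [0.0, 0.75, 0.9375, 0.984375, 0.99609375]: the root is never forced to stay
# a leaf, while deeper nodes are increasingly likely to remain leaves.
```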
@@ -463,7 +456,7 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)


def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
-    """Draw Gaussian distributed leaf values"""
+    """Draw Gaussian distributed leaf values."""
    if Y_mu_pred.size == 0:
        return 0
    else:
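The hunk stops at the `else:` branch. As a rough, hypothetical sketch (not the code hidden below the fold), a Gaussian leaf-value draw of this kind typically centers on the mean response in the leaf, scaled by the number of trees, plus Gaussian noise; every name below is an assumption for illustration.

```python
import numpy as np

def draw_leaf_value_sketch(y_in_leaf, m, mu_std, rng=None):
    # Hypothetical: center the draw on the mean response of the points that
    # reach the leaf, scaled by the number of trees m, plus Gaussian noise.
    rng = rng or np.random.default_rng()
    if y_in_leaf.size == 0:
        return 0.0
    return y_in_leaf.mean() / m + rng.normal(0.0, mu_std)

draw_leaf_value_sketch(np.array([1.0, 2.0, 3.0]), m=50, mu_std=0.1)
```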
@@ -504,9 +497,7 @@ def discrete_uniform_sampler(upper_value):


class NormalSampler:
-    """
-    Cache samples from a standard normal distribution
-    """
+    """Cache samples from a standard normal distribution."""

    def __init__(self):
        self.size = 1000
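A minimal sketch of the caching idea named in the docstring: draw a block of standard-normal samples once and serve them one at a time, refilling when the block is exhausted. The class name `CachedNormalSampler`, its `random()` method, and the refill policy are assumptions for illustration, not necessarily how `NormalSampler` is implemented here.

```python
import numpy as np

class CachedNormalSampler:
    """Toy cached sampler: draw a block of standard-normal values up front
    and hand them out one at a time, refilling the cache when it runs out."""

    def __init__(self, size=1000, rng=None):
        self.size = size
        self.rng = rng or np.random.default_rng()
        self.cache = self.rng.standard_normal(self.size)
        self.idx = 0

    def random(self):
        if self.idx == self.size:  # cache exhausted, draw a fresh block
            self.cache = self.rng.standard_normal(self.size)
            self.idx = 0
        value = self.cache[self.idx]
        self.idx += 1
        return value

sampler = CachedNormalSampler()
z = sampler.random()  # one standard-normal draw, served from the cache
```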