@@ -75,10 +75,13 @@ def __init__(
75
75
self .missing_data = np .any (np .isnan (self .X ))
76
76
self .m = self .bart .m
77
77
self .alpha = self .bart .alpha
78
- self .alpha_vec = self .bart .split_prior
79
- if self .alpha_vec is None :
80
- self .alpha_vec = np .ones (self .X .shape [1 ])
78
+ shape = initial_values [value_bart .name ].shape
79
+ if len (shape ) == 1 :
80
+ self .shape = 1
81
+ else :
82
+ self .shape = shape [0 ]
81
83
84
+ self .alpha_vec = self .bart .split_prior
82
85
self .init_mean = self .Y .mean ()
83
86
# if data is binary
84
87
Y_unique = np .unique (self .Y )
@@ -92,15 +95,19 @@ def __init__(
92
95
self .num_variates = self .X .shape [1 ]
93
96
self .available_predictors = list (range (self .num_variates ))
94
97
95
- self .sum_trees = np .full_like (self .Y , self .init_mean ).astype (aesara .config .floatX )
98
+ self .sum_trees = np .full ((self .shape , self .Y .shape [0 ]), self .init_mean ).astype (
99
+ aesara .config .floatX
100
+ )
101
+
96
102
self .a_tree = Tree .init_tree (
97
103
leaf_node_value = self .init_mean / self .m ,
98
104
idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
105
+ shape = self .shape ,
99
106
)
100
107
self .mean = fast_mean ()
101
108
102
- self .normal = NormalSampler (mu_std )
103
- self .uniform = UniformSampler (0.33 , 0.75 )
109
+ self .normal = NormalSampler (mu_std , self . shape )
110
+ self .uniform = UniformSampler (0.33 , 0.75 , self . shape )
104
111
self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
105
112
self .ssv = SampleSplittingVariable (self .alpha_vec )
106
113
@@ -120,7 +127,7 @@ def __init__(
120
127
self .len_indices = len (self .indices )
121
128
122
129
shared = make_shared_replacements (initial_values , vars , model )
123
- self .likelihood_logp = logp (initial_values , [model .datalogpt ], vars , shared )
130
+ self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
124
131
self .all_particles = []
125
132
for _ in range (self .m ):
126
133
self .a_tree .leaf_node_value = self .init_mean / self .m
@@ -154,6 +161,7 @@ def astep(self, _):
154
161
self .mean ,
155
162
self .m ,
156
163
self .normal ,
164
+ self .shape ,
157
165
)
158
166
if tree_grew :
159
167
self .update_weight (p )
@@ -226,6 +234,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
226
234
self .mean ,
227
235
self .m ,
228
236
self .normal ,
237
+ self .shape ,
229
238
)
230
239
231
240
# The old tree and the one with new leafs do not grow so we update the weights only once
@@ -250,7 +259,9 @@ def update_weight(self, particle, old=False):
250
259
Since the prior is used as the proposal, the weights are updated additively as the ratio of
251
260
the new and old log-likelihoods.
252
261
"""
253
- new_likelihood = self .likelihood_logp (self .sum_trees_noi + particle .tree ._predict ())
262
+ new_likelihood = self .likelihood_logp (
263
+ (self .sum_trees_noi + particle .tree ._predict ()).flatten ()
264
+ )
254
265
if old :
255
266
particle .log_weight = new_likelihood
256
267
particle .old_likelihood_logp = new_likelihood
@@ -289,6 +300,7 @@ def sample_tree(
289
300
mean ,
290
301
m ,
291
302
normal ,
303
+ shape ,
292
304
):
293
305
tree_grew = False
294
306
if self .expansion_nodes :
@@ -309,6 +321,7 @@ def sample_tree(
309
321
m ,
310
322
normal ,
311
323
self .kf ,
324
+ shape ,
312
325
)
313
326
if index_selected_predictor is not None :
314
327
new_indexes = self .tree .idx_leaf_nodes [- 2 :]
@@ -318,18 +331,19 @@ def sample_tree(
318
331
319
332
return tree_grew
320
333
321
def sample_leafs(self, sum_trees, mean, m, normal, shape):
    """Draw a new value for every leaf of ``self.tree`` except index 0.

    Parameters
    ----------
    sum_trees
        Array indexed as ``sum_trees[:, idx_data_points]``, i.e. data points
        run along the second axis.
    mean
        Callable forwarded to ``draw_leaf_value``.
    m
        Forwarded to ``draw_leaf_value`` (presumably the number of trees --
        confirm against the caller).
    normal
        Sampler object forwarded to ``draw_leaf_value``.
    shape
        Forwarded to ``draw_leaf_value``.

    The method returns nothing; it overwrites ``leaf.value`` in place.
    """
    for idx in self.tree.idx_leaf_nodes:
        # NOTE(review): index 0 is skipped -- presumably the root node; confirm.
        if idx > 0:
            leaf = self.tree[idx]
            idx_data_points = leaf.idx_data_points
            # Select only the columns (data points) that fall in this leaf.
            node_value = draw_leaf_value(
                sum_trees[:, idx_data_points],
                mean,
                m,
                normal,
                self.kf,
                shape,
            )
            leaf.value = node_value
335
349
@@ -390,6 +404,7 @@ def grow_tree(
390
404
m ,
391
405
normal ,
392
406
kf ,
407
+ shape ,
393
408
):
394
409
current_node = tree .get_node (index_leaf_node )
395
410
idx_data_points = current_node .idx_data_points
@@ -413,11 +428,12 @@ def grow_tree(
413
428
for idx in range (2 ):
414
429
idx_data_point = new_idx_data_points [idx ]
415
430
node_value = draw_leaf_value (
416
- sum_trees [idx_data_point ],
431
+ sum_trees [:, idx_data_point ],
417
432
mean ,
418
433
m ,
419
434
normal ,
420
435
kf ,
436
+ shape ,
421
437
)
422
438
423
439
new_node = LeafNode (
@@ -466,14 +482,14 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
466
482
return split_value
467
483
468
484
469
- def draw_leaf_value (Y_mu_pred , mean , m , normal , kf ):
485
+ def draw_leaf_value (Y_mu_pred , mean , m , normal , kf , shape ):
470
486
"""Draw Gaussian distributed leaf values."""
471
487
if Y_mu_pred .size == 0 :
472
- return 0
488
+ return np . zeros ( shape )
473
489
else :
474
490
norm = normal .random () * kf
475
491
if Y_mu_pred .size == 1 :
476
- mu_mean = Y_mu_pred .item () / m
492
+ mu_mean = np . full ( shape , Y_mu_pred .item () / m )
477
493
else :
478
494
mu_mean = mean (Y_mu_pred ) / m
479
495
def fast_mean():
    """Return a function that averages an array over its last axis.

    When Numba can be imported, the returned function is a JIT-compiled
    loop that handles 1-D input (returns a scalar) and 2-D input (returns
    one mean per row). Otherwise a NumPy-based fallback with the same
    behaviour for 1-D and 2-D input is returned.
    """
    try:
        from numba import jit
    except ImportError:
        from functools import partial

        # axis=-1 (rather than axis=1) keeps the fallback consistent with the
        # Numba implementation below: 1-D input yields a scalar instead of
        # raising an axis error, and 2-D input still yields row means.
        return partial(np.mean, axis=-1)

    @jit
    def mean(a):
        if a.ndim == 1:
            count = a.shape[0]
            suma = 0
            for i in range(count):
                suma += a[i]
            return suma / count
        elif a.ndim == 2:
            # One accumulator per row; columns are summed explicitly.
            res = np.zeros(a.shape[0])
            count = a.shape[1]
            for j in range(a.shape[0]):
                for i in range(count):
                    res[j] += a[j, i]
            return res / count

    return mean
500
526
@@ -510,36 +536,46 @@ def discrete_uniform_sampler(upper_value):
510
536
class NormalSampler:
    """Cache samples from a zero-mean normal distribution.

    Samples are generated in one call to ``np.random.normal`` as an array of
    shape ``(shape, size)`` and handed out one column at a time; the cache is
    regenerated once every column has been used.
    """

    def __init__(self, scale, shape):
        """Store the parameters and fill the first cache.

        Parameters
        ----------
        scale
            Standard deviation passed to ``np.random.normal``.
        shape
            Number of rows in the cache, i.e. the length of each draw
            returned by ``random``.
        """
        self.size = 1000
        self.scale = scale
        self.shape = shape
        self.update()

    def random(self):
        """Return the next cached column as a 1-D array of length ``shape``.

        The returned array is a view into the current cache, not a copy.
        """
        if self.idx == self.size:
            # Every column has been consumed: draw a fresh cache first.
            self.update()
        pop = self.cache[:, self.idx]
        self.idx += 1
        return pop

    def update(self):
        """Regenerate the cache and reset the read position to column 0."""
        self.idx = 0
        self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
525
555
526
556
527
557
class UniformSampler:
    """Cache samples from a uniform distribution.

    Samples are generated in one call to ``np.random.uniform`` as an array
    of shape ``(shape, size)`` and handed out one column at a time; the cache
    is regenerated once every column has been used.
    """

    def __init__(self, lower_bound, upper_bound, shape):
        """Store the parameters and fill the first cache.

        Parameters
        ----------
        lower_bound
            Lower bound passed to ``np.random.uniform``.
        upper_bound
            Upper bound passed to ``np.random.uniform``.
        shape
            Number of rows in the cache, i.e. the length of each draw
            returned by ``random``.
        """
        self.size = 1000
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound
        self.shape = shape
        self.update()

    def random(self):
        """Return the next cached column as a 1-D array of length ``shape``.

        The returned array is a view into the current cache, not a copy.
        """
        if self.idx == self.size:
            # Every column has been consumed: draw a fresh cache first.
            self.update()
        pop = self.cache[:, self.idx]
        self.idx += 1
        return pop

    def update(self):
        """Regenerate the cache and reset the read position to column 0."""
        self.idx = 0
        self.cache = np.random.uniform(
            self.lower_bound, self.upper_bound, size=(self.shape, self.size)
        )
543
579
544
580
545
581
def logp (point , out_vars , vars , shared ):
0 commit comments