Skip to content

Commit 44505f6

Browse files
authored
BART: add shape argument (#46)
* add shape argument * switch dimensions * clean samplers * use size
1 parent 29ee733 commit 44505f6

File tree

5 files changed

+172
-122
lines changed

5 files changed

+172
-122
lines changed

pymc_experimental/bart/bart.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717

1818
from aeppl.logprob import _logprob
19-
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
19+
from aesara.tensor.random.op import RandomVariable
2020
from pandas import DataFrame, Series
2121

2222
from pymc.distributions.distribution import NoDistribution, _moment
@@ -29,23 +29,20 @@ class BARTRV(RandomVariable):
2929

3030
name = "BART"
3131
ndim_supp = 1
32-
ndims_params = [2, 1, 0, 0, 0, 1]
32+
ndims_params = [2, 1, 0, 0, 1]
3333
dtype = "floatX"
3434
_print_name = ("BART", "\\operatorname{BART}")
3535
all_trees = None
3636

37-
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
38-
return default_supp_shape_from_params(
39-
self.ndim_supp, dist_params, rep_param_idx, param_shapes
40-
)
41-
42-
def _infer_shape(cls, size, dist_params, param_shapes=None):
43-
dist_shape = (cls.X.shape[0],)
44-
return dist_shape
37+
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
38+
return (self.X.shape[0],)
4539

4640
@classmethod
47-
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
48-
return np.full_like(cls.Y, cls.Y.mean())
41+
def rng_fn(cls, rng, X, Y, m, alpha, split_prior, size):
42+
if size is not None:
43+
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
44+
else:
45+
return np.full(cls.Y.shape[0], cls.Y.mean())
4946

5047

5148
bart = BARTRV()
@@ -87,6 +84,9 @@ def __new__(
8784

8885
X, Y = preprocess_XY(X, Y)
8986

87+
if split_prior is None:
88+
split_prior = np.ones(X.shape[1])
89+
9090
bart_op = type(
9191
f"BART_{name}",
9292
(BARTRV,),
@@ -109,7 +109,7 @@ def get_moment(rv, size, *rv_inputs):
109109
return cls.get_moment(rv, size, *rv_inputs)
110110

111111
cls.rv_op = bart_op
112-
params = [X, Y, m, alpha]
112+
params = [X, Y, m, alpha, split_prior]
113113
return super().__new__(cls, name, *params, **kwargs)
114114

115115
@classmethod
@@ -141,7 +141,6 @@ def preprocess_XY(X, Y):
141141
Y = Y.to_numpy()
142142
if isinstance(X, (Series, DataFrame)):
143143
X = X.to_numpy()
144-
# X = np.random.normal(X, X.std(0)/100)
145144
Y = Y.astype(float)
146145
X = X.astype(float)
147146
return X, Y

pymc_experimental/bart/pgbart.py

+67-31
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,13 @@ def __init__(
7575
self.missing_data = np.any(np.isnan(self.X))
7676
self.m = self.bart.m
7777
self.alpha = self.bart.alpha
78-
self.alpha_vec = self.bart.split_prior
79-
if self.alpha_vec is None:
80-
self.alpha_vec = np.ones(self.X.shape[1])
78+
shape = initial_values[value_bart.name].shape
79+
if len(shape) == 1:
80+
self.shape = 1
81+
else:
82+
self.shape = shape[0]
8183

84+
self.alpha_vec = self.bart.split_prior
8285
self.init_mean = self.Y.mean()
8386
# if data is binary
8487
Y_unique = np.unique(self.Y)
@@ -92,15 +95,19 @@ def __init__(
9295
self.num_variates = self.X.shape[1]
9396
self.available_predictors = list(range(self.num_variates))
9497

95-
self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
98+
self.sum_trees = np.full((self.shape, self.Y.shape[0]), self.init_mean).astype(
99+
aesara.config.floatX
100+
)
101+
96102
self.a_tree = Tree.init_tree(
97103
leaf_node_value=self.init_mean / self.m,
98104
idx_data_points=np.arange(self.num_observations, dtype="int32"),
105+
shape=self.shape,
99106
)
100107
self.mean = fast_mean()
101108

102-
self.normal = NormalSampler(mu_std)
103-
self.uniform = UniformSampler(0.33, 0.75)
109+
self.normal = NormalSampler(mu_std, self.shape)
110+
self.uniform = UniformSampler(0.33, 0.75, self.shape)
104111
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
105112
self.ssv = SampleSplittingVariable(self.alpha_vec)
106113

@@ -120,7 +127,7 @@ def __init__(
120127
self.len_indices = len(self.indices)
121128

122129
shared = make_shared_replacements(initial_values, vars, model)
123-
self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
130+
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
124131
self.all_particles = []
125132
for _ in range(self.m):
126133
self.a_tree.leaf_node_value = self.init_mean / self.m
@@ -154,6 +161,7 @@ def astep(self, _):
154161
self.mean,
155162
self.m,
156163
self.normal,
164+
self.shape,
157165
)
158166
if tree_grew:
159167
self.update_weight(p)
@@ -226,6 +234,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
226234
self.mean,
227235
self.m,
228236
self.normal,
237+
self.shape,
229238
)
230239

231240
# The old tree and the one with new leafs do not grow so we update the weights only once
@@ -250,7 +259,9 @@ def update_weight(self, particle, old=False):
250259
Since the prior is used as the proposal,the weights are updated additively as the ratio of
251260
the new and old log-likelihoods.
252261
"""
253-
new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree._predict())
262+
new_likelihood = self.likelihood_logp(
263+
(self.sum_trees_noi + particle.tree._predict()).flatten()
264+
)
254265
if old:
255266
particle.log_weight = new_likelihood
256267
particle.old_likelihood_logp = new_likelihood
@@ -289,6 +300,7 @@ def sample_tree(
289300
mean,
290301
m,
291302
normal,
303+
shape,
292304
):
293305
tree_grew = False
294306
if self.expansion_nodes:
@@ -309,6 +321,7 @@ def sample_tree(
309321
m,
310322
normal,
311323
self.kf,
324+
shape,
312325
)
313326
if index_selected_predictor is not None:
314327
new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -318,18 +331,19 @@ def sample_tree(
318331

319332
return tree_grew
320333

321-
def sample_leafs(self, sum_trees, mean, m, normal):
334+
def sample_leafs(self, sum_trees, mean, m, normal, shape):
322335

323336
for idx in self.tree.idx_leaf_nodes:
324337
if idx > 0:
325338
leaf = self.tree[idx]
326339
idx_data_points = leaf.idx_data_points
327340
node_value = draw_leaf_value(
328-
sum_trees[idx_data_points],
341+
sum_trees[:, idx_data_points],
329342
mean,
330343
m,
331344
normal,
332345
self.kf,
346+
shape,
333347
)
334348
leaf.value = node_value
335349

@@ -390,6 +404,7 @@ def grow_tree(
390404
m,
391405
normal,
392406
kf,
407+
shape,
393408
):
394409
current_node = tree.get_node(index_leaf_node)
395410
idx_data_points = current_node.idx_data_points
@@ -413,11 +428,12 @@ def grow_tree(
413428
for idx in range(2):
414429
idx_data_point = new_idx_data_points[idx]
415430
node_value = draw_leaf_value(
416-
sum_trees[idx_data_point],
431+
sum_trees[:, idx_data_point],
417432
mean,
418433
m,
419434
normal,
420435
kf,
436+
shape,
421437
)
422438

423439
new_node = LeafNode(
@@ -466,14 +482,14 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
466482
return split_value
467483

468484

469-
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf):
485+
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
470486
"""Draw Gaussian distributed leaf values."""
471487
if Y_mu_pred.size == 0:
472-
return 0
488+
return np.zeros(shape)
473489
else:
474490
norm = normal.random() * kf
475491
if Y_mu_pred.size == 1:
476-
mu_mean = Y_mu_pred.item() / m
492+
mu_mean = np.full(shape, Y_mu_pred.item() / m)
477493
else:
478494
mu_mean = mean(Y_mu_pred) / m
479495

@@ -486,15 +502,25 @@ def fast_mean():
486502
try:
487503
from numba import jit
488504
except ImportError:
489-
return np.mean
505+
from functools import partial
506+
507+
return partial(np.mean, axis=1)
490508

491509
@jit
492510
def mean(a):
493-
count = a.shape[0]
494-
suma = 0
495-
for i in range(count):
496-
suma += a[i]
497-
return suma / count
511+
if a.ndim == 1:
512+
count = a.shape[0]
513+
suma = 0
514+
for i in range(count):
515+
suma += a[i]
516+
return suma / count
517+
elif a.ndim == 2:
518+
res = np.zeros(a.shape[0])
519+
count = a.shape[1]
520+
for j in range(a.shape[0]):
521+
for i in range(count):
522+
res[j] += a[j, i]
523+
return res / count
498524

499525
return mean
500526

@@ -510,36 +536,46 @@ def discrete_uniform_sampler(upper_value):
510536
class NormalSampler:
511537
"""Cache samples from a standard normal distribution."""
512538

513-
def __init__(self, scale):
539+
def __init__(self, scale, shape):
514540
self.size = 1000
515-
self.cache = []
516541
self.scale = scale
542+
self.shape = shape
543+
self.update()
517544

518545
def random(self):
519-
if not self.cache:
546+
if self.idx == self.size:
520547
self.update()
521-
return self.cache.pop()
548+
pop = self.cache[:, self.idx]
549+
self.idx += 1
550+
return pop
522551

523552
def update(self):
524-
self.cache = np.random.normal(loc=0.0, scale=self.scale, size=self.size).tolist()
553+
self.idx = 0
554+
self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
525555

526556

527557
class UniformSampler:
528558
"""Cache samples from a uniform distribution."""
529559

530-
def __init__(self, lower_bound, upper_bound):
560+
def __init__(self, lower_bound, upper_bound, shape):
531561
self.size = 1000
532-
self.cache = []
533-
self.lower_bound = lower_bound
534562
self.upper_bound = upper_bound
563+
self.lower_bound = lower_bound
564+
self.shape = shape
565+
self.update()
535566

536567
def random(self):
537-
if not self.cache:
568+
if self.idx == self.size:
538569
self.update()
539-
return self.cache.pop()
570+
pop = self.cache[:, self.idx]
571+
self.idx += 1
572+
return pop
540573

541574
def update(self):
542-
self.cache = np.random.uniform(self.lower_bound, self.upper_bound, size=self.size).tolist()
575+
self.idx = 0
576+
self.cache = np.random.uniform(
577+
self.lower_bound, self.upper_bound, size=(self.shape, self.size)
578+
)
543579

544580

545581
def logp(point, out_vars, vars, shared):

pymc_experimental/bart/tree.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ class Tree:
4646
num_observations : int, optional
4747
"""
4848

49-
def __init__(self, num_observations=0):
49+
def __init__(self, num_observations=0, shape=1):
5050
self.tree_structure = {}
5151
self.idx_leaf_nodes = []
52-
self.num_observations = num_observations
52+
self.shape = shape
53+
self.output = (
54+
np.zeros((num_observations, self.shape)).astype(aesara.config.floatX).squeeze()
55+
)
5356

5457
def __getitem__(self, index):
5558
return self.get_node(index)
@@ -74,7 +77,7 @@ def delete_leaf_node(self, index):
7477

7578
def trim(self):
7679
a_tree = self.copy()
77-
del a_tree.num_observations
80+
del a_tree.output
7881
del a_tree.idx_leaf_nodes
7982
for k in a_tree.tree_structure.keys():
8083
current_node = a_tree[k]
@@ -84,12 +87,11 @@ def trim(self):
8487
return a_tree
8588

8689
def _predict(self):
87-
output = np.zeros(self.num_observations)
90+
output = self.output
8891
for node_index in self.idx_leaf_nodes:
8992
leaf_node = self.get_node(node_index)
9093
output[leaf_node.idx_data_points] = leaf_node.value
91-
92-
return output.astype(aesara.config.floatX)
94+
return output.T
9395

9496
def predict(self, X, excluded=None):
9597
"""
@@ -110,7 +112,7 @@ def predict(self, X, excluded=None):
110112
if excluded is not None:
111113
parent_node = leaf_node.get_idx_parent_node()
112114
if self.get_node(parent_node).idx_split_variable in excluded:
113-
leaf_value = 0.0
115+
leaf_value = np.zeros(self.shape)
114116
return leaf_value
115117

116118
def _traverse_tree(self, x, node_index=0):
@@ -137,7 +139,7 @@ def _traverse_tree(self, x, node_index=0):
137139
return current_node
138140

139141
@staticmethod
140-
def init_tree(leaf_node_value, idx_data_points):
142+
def init_tree(leaf_node_value, idx_data_points, shape):
141143
"""
142144
Initialize tree.
143145
@@ -150,7 +152,7 @@ def init_tree(leaf_node_value, idx_data_points):
150152
-------
151153
tree
152154
"""
153-
new_tree = Tree(len(idx_data_points))
155+
new_tree = Tree(len(idx_data_points), shape)
154156
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
155157
return new_tree
156158

0 commit comments

Comments
 (0)