Skip to content

Commit c4284a1

Browse files
authored
Remove k argument (#40)
* refactor * use keys * sample k
1 parent 5def470 commit c4284a1

File tree

3 files changed

+79
-72
lines changed

3 files changed

+79
-72
lines changed

pymc_experimental/bart/bart.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ class BART(NoDistribution):
6868
alpha : float
6969
Control the prior probability over the depth of the trees. Even when it can takes values in
7070
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
71-
k : float
72-
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
73-
and 3.
7471
split_prior : array-like
7572
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7673
1. Otherwise they will be normalized.
@@ -84,7 +81,6 @@ def __new__(
8481
Y,
8582
m=50,
8683
alpha=0.25,
87-
k=2,
8884
split_prior=None,
8985
**kwargs,
9086
):
@@ -102,7 +98,6 @@ def __new__(
10298
Y=Y,
10399
m=m,
104100
alpha=alpha,
105-
k=k,
106101
split_prior=split_prior,
107102
),
108103
)()
@@ -114,7 +109,7 @@ def get_moment(rv, size, *rv_inputs):
114109
return cls.get_moment(rv, size, *rv_inputs)
115110

116111
cls.rv_op = bart_op
117-
params = [X, Y, m, alpha, k]
112+
params = [X, Y, m, alpha]
118113
return super().__new__(cls, name, *params, **kwargs)
119114

120115
@classmethod

pymc_experimental/bart/pgbart.py

+77-65
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ class PGBART(ArrayStepShared):
4040
List of value variables for sampler
4141
num_particles : int
4242
Number of particles for the conditional SMC sampler. Defaults to 40
43-
max_stages : int
44-
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
4543
batch : int or tuple
4644
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
4745
during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -59,7 +57,6 @@ def __init__(
5957
self,
6058
vars=None,
6159
num_particles=40,
62-
max_stages=100,
6360
batch="auto",
6461
model=None,
6562
):
@@ -78,7 +75,6 @@ def __init__(
7875
self.missing_data = np.any(np.isnan(self.X))
7976
self.m = self.bart.m
8077
self.alpha = self.bart.alpha
81-
self.k = self.bart.k
8278
self.alpha_vec = self.bart.split_prior
8379
if self.alpha_vec is None:
8480
self.alpha_vec = np.ones(self.X.shape[1])
@@ -87,10 +83,10 @@ def __init__(
8783
# if data is binary
8884
Y_unique = np.unique(self.Y)
8985
if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
90-
self.mu_std = 6 / (self.k * self.m**0.5)
86+
mu_std = 3 / self.m**0.5
9187
# maybe we need to check for count data
9288
else:
93-
self.mu_std = (2 * self.Y.std()) / (self.k * self.m**0.5)
89+
mu_std = self.Y.std() / self.m**0.5
9490

9591
self.num_observations = self.X.shape[0]
9692
self.num_variates = self.X.shape[1]
@@ -103,7 +99,8 @@ def __init__(
10399
)
104100
self.mean = fast_mean()
105101

106-
self.normal = NormalSampler()
102+
self.normal = NormalSampler(mu_std)
103+
self.uniform = UniformSampler(0.33, 0.75)
107104
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
108105
self.ssv = SampleSplittingVariable(self.alpha_vec)
109106

@@ -121,12 +118,11 @@ def __init__(
121118
self.log_num_particles = np.log(num_particles)
122119
self.indices = list(range(2, num_particles))
123120
self.len_indices = len(self.indices)
124-
self.max_stages = max_stages
125121

126122
shared = make_shared_replacements(initial_values, vars, model)
127123
self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
128124
self.all_particles = []
129-
for i in range(self.m):
125+
for _ in range(self.m):
130126
self.a_tree.leaf_node_value = self.init_mean / self.m
131127
p = ParticleTree(self.a_tree)
132128
self.all_particles.append(p)
@@ -138,25 +134,13 @@ def astep(self, _):
138134

139135
tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
140136
for tree_id in tree_ids:
137+
# Compute the sum of trees without the old tree that we are attempting to replace
138+
self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
141139
# Generate an initial set of SMC particles
142140
# at the end of the algorithm we return one of these particles as the new tree
143141
particles = self.init_particles(tree_id)
144-
# Compute the sum of trees without the old tree, that we are attempting to replace
145-
self.sum_trees_noi = self.sum_trees - particles[0].tree._predict()
146-
# Resample leaf values for particle 1 which is a copy of the old tree
147-
particles[1].sample_leafs(
148-
self.sum_trees,
149-
self.X,
150-
self.mean,
151-
self.m,
152-
self.normal,
153-
self.mu_std,
154-
)
155142

156-
# The old tree and the one with new leafs do not grow so we update the weights only once
157-
self.update_weight(particles[0], old=True)
158-
self.update_weight(particles[1], old=True)
159-
for _ in range(self.max_stages):
143+
while True:
160144
# Sample each particle (try to grow each tree), except for the first two
161145
stop_growing = True
162146
for p in particles[2:]:
@@ -170,26 +154,26 @@ def astep(self, _):
170154
self.mean,
171155
self.m,
172156
self.normal,
173-
self.mu_std,
174157
)
175158
if tree_grew:
176159
self.update_weight(p)
177160
if p.expansion_nodes:
178161
stop_growing = False
179162
if stop_growing:
180163
break
164+
181165
# Normalize weights
182-
W_t, normalized_weights = self.normalize(particles[2:])
166+
w_t, normalized_weights = self.normalize(particles[2:])
183167

184168
# Resample all but first two particles
185169
new_indices = np.random.choice(
186170
self.indices, size=self.len_indices, p=normalized_weights
187171
)
188172
particles[2:] = particles[new_indices]
189173

190-
# Set the new weights
174+
# Set the new weight
191175
for p in particles[2:]:
192-
p.log_weight = W_t
176+
p.log_weight = w_t
193177

194178
for p in particles[2:]:
195179
p.log_weight = p.old_likelihood_logp
@@ -216,27 +200,45 @@ def astep(self, _):
216200
return self.sum_trees, [stats]
217201

218202
def normalize(self, particles):
219-
"""Use logsumexp trick to get W_t and softmax to get normalized_weights."""
203+
"""Use logsumexp trick to get w_t and softmax to get normalized_weights.
204+
205+
w_t is the un-normalized weight per particle, we will assign it to the
206+
next round of particles, so they all start with the same weight.
207+
"""
220208
log_w = np.array([p.log_weight for p in particles])
221209
log_w_max = log_w.max()
222210
log_w_ = log_w - log_w_max
223211
w_ = np.exp(log_w_)
224212
w_sum = w_.sum()
225-
W_t = log_w_max + np.log(w_sum) - self.log_num_particles
213+
w_t = log_w_max + np.log(w_sum) - self.log_num_particles
226214
normalized_weights = w_ / w_sum
227215
# stabilize weights to avoid assigning exactly zero probability to a particle
228216
normalized_weights += 1e-12
229217

230-
return W_t, normalized_weights
218+
return w_t, normalized_weights
231219

232220
def init_particles(self, tree_id: int) -> np.ndarray:
233221
"""Initialize particles."""
234-
p = self.all_particles[tree_id]
235-
particles = [p]
236-
particles.append(copy(p))
222+
p0 = self.all_particles[tree_id]
223+
p1 = copy(p0)
224+
p1.sample_leafs(
225+
self.sum_trees,
226+
self.mean,
227+
self.m,
228+
self.normal,
229+
)
230+
# The old tree and the one with new leafs do not grow so we update the weights only once
231+
self.update_weight(p0, old=True)
232+
self.update_weight(p1, old=True)
237233

234+
particles = [p0, p1]
238235
for _ in self.indices:
239-
particles.append(ParticleTree(self.a_tree))
236+
pt = ParticleTree(self.a_tree)
237+
if self.tune:
238+
pt.kf = self.uniform.random()
239+
else:
240+
pt.kf = p0.kf
241+
particles.append(pt)
240242

241243
return np.array(particles)
242244

@@ -273,6 +275,7 @@ def __init__(self, tree):
273275
self.log_weight = 0
274276
self.old_likelihood_logp = 0
275277
self.used_variates = []
278+
self.kf = 0.75
276279

277280
def sample_tree(
278281
self,
@@ -285,7 +288,6 @@ def sample_tree(
285288
mean,
286289
m,
287290
normal,
288-
mu_std,
289291
):
290292
tree_grew = False
291293
if self.expansion_nodes:
@@ -305,7 +307,7 @@ def sample_tree(
305307
mean,
306308
m,
307309
normal,
308-
mu_std,
310+
self.kf,
309311
)
310312
if index_selected_predictor is not None:
311313
new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -315,9 +317,20 @@ def sample_tree(
315317

316318
return tree_grew
317319

318-
def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
320+
def sample_leafs(self, sum_trees, mean, m, normal):
319321

320-
sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
322+
for idx in self.tree.idx_leaf_nodes:
323+
if idx > 0:
324+
leaf = self.tree[idx]
325+
idx_data_points = leaf.idx_data_points
326+
node_value = draw_leaf_value(
327+
sum_trees[idx_data_points],
328+
mean,
329+
m,
330+
normal,
331+
self.kf,
332+
)
333+
leaf.value = node_value
321334

322335

323336
class SampleSplittingVariable:
@@ -375,7 +388,7 @@ def grow_tree(
375388
mean,
376389
m,
377390
normal,
378-
mu_std,
391+
kf,
379392
):
380393
current_node = tree.get_node(index_leaf_node)
381394
idx_data_points = current_node.idx_data_points
@@ -406,11 +419,10 @@ def grow_tree(
406419
idx_data_point = new_idx_data_points[idx]
407420
node_value = draw_leaf_value(
408421
sum_trees[idx_data_point],
409-
X[idx_data_point, selected_predictor],
410422
mean,
411423
m,
412424
normal,
413-
mu_std,
425+
kf,
414426
)
415427

416428
new_node = LeafNode(
@@ -435,25 +447,6 @@ def grow_tree(
435447
return index_selected_predictor
436448

437449

438-
def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
439-
440-
for idx in tree.idx_leaf_nodes:
441-
if idx > 0:
442-
leaf = tree[idx]
443-
idx_data_points = leaf.idx_data_points
444-
parent_node = tree[leaf.get_idx_parent_node()]
445-
selected_predictor = parent_node.idx_split_variable
446-
node_value = draw_leaf_value(
447-
sum_trees[idx_data_points],
448-
X[idx_data_points, selected_predictor],
449-
mean,
450-
m,
451-
normal,
452-
mu_std,
453-
)
454-
leaf.value = node_value
455-
456-
457450
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
458451

459452
left_idx = X[idx_data_points, selected_predictor] <= split_value
@@ -463,12 +456,12 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
463456
return left_node_idx_data_points, right_node_idx_data_points
464457

465458

466-
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
459+
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf):
467460
"""Draw Gaussian distributed leaf values."""
468461
if Y_mu_pred.size == 0:
469462
return 0
470463
else:
471-
norm = normal.random() * mu_std
464+
norm = normal.random() * kf
472465
if Y_mu_pred.size == 1:
473466
mu_mean = Y_mu_pred.item() / m
474467
else:
@@ -507,17 +500,36 @@ def discrete_uniform_sampler(upper_value):
507500
class NormalSampler:
508501
"""Cache samples from a standard normal distribution."""
509502

510-
def __init__(self):
503+
def __init__(self, scale):
504+
self.size = 1000
505+
self.cache = []
506+
self.scale = scale
507+
508+
def random(self):
509+
if not self.cache:
510+
self.update()
511+
return self.cache.pop()
512+
513+
def update(self):
514+
self.cache = np.random.normal(loc=0.0, scale=self.scale, size=self.size).tolist()
515+
516+
517+
class UniformSampler:
518+
"""Cache samples from a uniform distribution."""
519+
520+
def __init__(self, lower_bound, upper_bound):
511521
self.size = 1000
512522
self.cache = []
523+
self.lower_bound = lower_bound
524+
self.upper_bound = upper_bound
513525

514526
def random(self):
515527
if not self.cache:
516528
self.update()
517529
return self.cache.pop()
518530

519531
def update(self):
520-
self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist()
532+
self.cache = np.random.uniform(self.lower_bound, self.upper_bound, size=self.size).tolist()
521533

522534

523535
def logp(point, out_vars, vars, shared):

pymc_experimental/bart/tree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def trim(self):
7878
a_tree = self.copy()
7979
del a_tree.num_observations
8080
del a_tree.idx_leaf_nodes
81-
for k, v in a_tree.tree_structure.items():
81+
for k in a_tree.tree_structure.keys():
8282
current_node = a_tree[k]
8383
del current_node.depth
8484
if isinstance(current_node, LeafNode):

0 commit comments

Comments
 (0)