Skip to content

Commit 5023a70

Browse files
aloctavodiatwiecki
authored andcommitted
fix predictions when using 1 chain
1 parent 55463de commit 5023a70

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

pymc/bart/bart.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,10 @@ class BART(NoDistribution):
6262
m : int
6363
Number of trees
6464
alpha : float
65-
Control the prior probability over the depth of the trees. Even when it can takes values in
66-
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
65+
Control the prior probability over the depth of the trees. Defaults to 0.25.
66+
It is recommended to be in the interval (0, 0.5].
6767
k : float
68-
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
69-
and 3.
68+
Scale parameter for the values of the leaf nodes. Defaults to 1.
7069
split_prior : array-like
7170
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7271
1. Otherwise they will be normalized.

pymc/bart/pgbart.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616

17-
from copy import copy
17+
from copy import copy, deepcopy
1818

1919
import aesara
2020
import numpy as np
@@ -39,12 +39,12 @@ class PGBART(ArrayStepShared):
3939
vars: list
4040
List of value variables for sampler
4141
num_particles : int
42-
Number of particles for the conditional SMC sampler. Defaults to 40
42+
Number of particles. Defaults to 40
4343
max_stages : int
44-
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
44+
Maximum number of iterations. Defaults to 100.
4545
batch : int or tuple
4646
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
47-
during tuning and after tuning. If a tuple is passed the first element is the batch size
47+
during and after tuning. If a tuple is passed the first element is the batch size
4848
during tuning and the second the batch size after tuning.
4949
model: PyMC Model
5050
Optional model for sampling step. Defaults to None (taken from context).
@@ -81,10 +81,10 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
8181
# if data is binary
8282
Y_unique = np.unique(self.Y)
8383
if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
84-
self.mu_std = 6 / (self.k * self.m**0.5)
84+
self.mu_std = 3 / (self.k * self.m**0.5)
8585
# maybe we need to check for count data
8686
else:
87-
self.mu_std = (2 * self.Y.std()) / (self.k * self.m**0.5)
87+
self.mu_std = self.Y.std() / (self.k * self.m**0.5)
8888

8989
self.num_observations = self.X.shape[0]
9090
self.num_variates = self.X.shape[1]
@@ -229,8 +229,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
229229
Initialize particles
230230
"""
231231
p = self.all_particles[tree_id]
232-
particles = [p]
233-
particles.append(copy(p))
232+
particles = [p, p.copy()]
234233

235234
for _ in self.indices:
236235
particles.append(ParticleTree(self.a_tree))
@@ -275,6 +274,9 @@ def __init__(self, tree):
275274
self.old_likelihood_logp = 0
276275
self.used_variates = []
277276

277+
def copy(self):
278+
return deepcopy(self)
279+
278280
def sample_tree(
279281
self,
280282
ssv,

0 commit comments

Comments
 (0)