Skip to content

Commit 78b6f79

Browse files
committed
update variable importance report
1 parent 7ac976b commit 78b6f79

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

pymc3/sampling.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,10 +608,12 @@ def sample(
608608

609609
if "variable_inclusion" in trace.stat_names:
610610
variable_inclusion = np.vstack(trace.get_sampler_stats("variable_inclusion"))
611-
variable_inclusion = np.split(variable_inclusion, 50)
612-
dada = np.vstack([v.sum(0) / v.sum() for v in variable_inclusion])
613-
trace.report.variable_importance_m = dada.mean(0)
614-
trace.report.variable_importance_s = dada.std(0)
611+
variable_inclusion = np.vstack(
612+
[v.sum(0) / v.sum() for v in np.array_split(variable_inclusion, 50)]
613+
)
614+
trace.report.variable_importance = np.empty((variable_inclusion.shape[1], 3))
615+
trace.report.variable_importance[:, 0] = variable_inclusion.mean(0)
616+
trace.report.variable_importance[:, 1:3] = arviz.hdi(variable_inclusion, hdi_prob=0.68)
615617

616618
n_chains = len(trace.chains)
617619
_log.info(

pymc3/step_methods/pgbart.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
6060
self.idx = 0
6161
if chunk == "auto":
6262
self.chunk = max(1, int(self.bart.m * 0.1))
63-
self.variable_inclusion = np.zeros(self.bart.num_variates, dtype="int")
6463
self.num_particles = num_particles
6564
self.log_num_particles = np.log(num_particles)
6665
self.indices = list(range(1, num_particles))
@@ -77,7 +76,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
7776
def astep(self, _):
7877
bart = self.bart
7978
num_observations = bart.num_observations
80-
variable_inclusion = self.variable_inclusion
79+
variable_inclusion = np.zeros(bart.num_variates, dtype="int")
8180

8281
# For the tunning phase we restrict max_stages to a low number, otherwise it is almost sure
8382
# we will reach max_stages given that our first set of m trees is not good at all.
@@ -142,7 +141,6 @@ def astep(self, _):
142141
bart.sum_trees_output = bart.Y - R_j + new_prediction
143142

144143
if not self.tune:
145-
variable_inclusion = self.variable_inclusion
146144
for index in new_tree.used_variates:
147145
variable_inclusion[index] += 1
148146

0 commit comments

Comments
 (0)