Skip to content

Commit bb69a76

Browse files
committed
test variable importance
1 parent 2dda3b0 commit bb69a76

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

pymc3/tests/test_sampling.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ def test_trace_report(self, step_cls, discard):
166166
assert isinstance(trace.report.t_sampling, float)
167167
pass
168168

169+
def test_trace_report_bart():
170+
X = np.random.normal(0, 1, size=(3, 250)).T
171+
Y = np.random.normal(0, 1, size=250)
172+
X[:, 0] = np.random.normal(Y, 0.1)
173+
174+
with pm.Model() as model:
175+
mu = pm.BART("mu", X, Y, m=20)
176+
sigma = pm.HalfNormal("sigma", 1)
177+
y = pm.Normal("y", mu, sigma, observed=Y)
178+
trace = pm.sample(500, tune=100, chains=1, random_seed=3415)
179+
var_imp = trace.report.variable_importance
180+
assert var_imp[0] > var_imp[1:].sum()
181+
npt.assert_almost_equal(var_imp.sum(), 1)
182+
169183
def test_return_inferencedata(self):
170184
with self.model:
171185
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())

0 commit comments

Comments
 (0)