We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2dda3b0 commit bb69a76Copy full SHA for bb69a76
pymc3/tests/test_sampling.py
@@ -166,6 +166,20 @@ def test_trace_report(self, step_cls, discard):
166
assert isinstance(trace.report.t_sampling, float)
167
pass
168
169
+ def test_trace_report_bart():
170
+ X = np.random.normal(0, 1, size=(3, 250)).T
171
+ Y = np.random.normal(0, 1, size=250)
172
+ X[:, 0] = np.random.normal(Y, 0.1)
173
+
174
+ with pm.Model() as model:
175
+ mu = pm.BART("mu", X, Y, m=20)
176
+ sigma = pm.HalfNormal("sigma", 1)
177
+ y = pm.Normal("y", mu, sigma, observed=Y)
178
+ trace = pm.sample(500, tune=100, chains=1, random_seed=3415)
179
+ var_imp = trace.report.variable_importance
180
+ assert var_imp[0] > var_imp[1:].sum()
181
+ npt.assert_almost_equal(var_imp.sum(), 1)
182
183
def test_return_inferencedata(self):
184
with self.model:
185
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
0 commit comments