Skip to content

Commit d1982dc

Browse files
committed
add tests
1 parent 14d2128 commit d1982dc

File tree

2 files changed

+69
-19
lines changed

2 files changed

+69
-19
lines changed

pymc3/tests/test_bart.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
3+
import pymc3 as pm
4+
5+
6+
def test_split_node():
7+
split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
8+
assert split_node.index == 5
9+
assert split_node.idx_split_variable == 2
10+
assert split_node.split_value == 3.0
11+
assert split_node.depth == 2
12+
assert split_node.get_idx_parent_node() == 2
13+
assert split_node.get_idx_left_child() == 11
14+
assert split_node.get_idx_right_child() == 12
15+
16+
17+
def test_leaf_node():
18+
leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
19+
assert leaf_node.index == 5
20+
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
21+
assert leaf_node.value == 3.14
22+
assert leaf_node.get_idx_parent_node() == 2
23+
assert leaf_node.get_idx_left_child() == 11
24+
assert leaf_node.get_idx_right_child() == 12
25+
26+
27+
def test_model():
28+
X = np.linspace(7, 15, 100)
29+
Y = np.sin(np.random.normal(X, 0.2)) + 3
30+
X = X[:, None]
31+
32+
with pm.Model() as model:
33+
sigma = pm.HalfNormal("sigma", 1)
34+
mu = pm.BART("mu", X, Y, m=50)
35+
y = pm.Normal("y", mu, sigma, observed=Y)
36+
idata = pm.sample()
37+
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")
38+
39+
np.testing.assert_allclose(mean, Y, 0.5)
40+
41+
Y = np.repeat([0, 1], 50)
42+
with pm.Model() as model:
43+
mu_ = pm.BART("mu_", X, Y, m=50)
44+
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
45+
y = pm.Bernoulli("y", mu, observed=Y)
46+
idata = pm.sample()
47+
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")
48+
49+
np.testing.assert_allclose(mean, Y, atol=0.5)
50+
51+
52+
def test_bart_vi():
53+
X = np.random.normal(0, 1, size=(3, 250)).T
54+
Y = np.random.normal(0, 1, size=250)
55+
X[:, 0] = np.random.normal(Y, 0.1)
56+
57+
with pm.Model() as model:
58+
mu = pm.BART("mu", X, Y, m=10)
59+
sigma = pm.HalfNormal("sigma", 1)
60+
y = pm.Normal("y", mu, sigma, observed=Y)
61+
idata = pm.sample(random_seed=3415, chains=1)
62+
var_imp = (
63+
idata.sample_stats["variable_inclusion"]
64+
.stack(samples=("chain", "draw"))
65+
.mean("samples")
66+
)
67+
var_imp /= var_imp.sum()
68+
assert var_imp[0] > var_imp[1:].sum()
69+
np.testing.assert_almost_equal(var_imp.sum(), 1)

pymc3/tests/test_sampling.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -174,25 +174,6 @@ def test_trace_report(self, step_cls, discard):
174174
assert trace.report.n_draws == 100
175175
assert isinstance(trace.report.t_sampling, float)
176176

177-
def test_bart_vi(self):
178-
X = np.random.normal(0, 1, size=(3, 250)).T
179-
Y = np.random.normal(0, 1, size=250)
180-
X[:, 0] = np.random.normal(Y, 0.1)
181-
182-
with pm.Model() as model:
183-
mu = pm.BART("mu", X, Y, m=10)
184-
sigma = pm.HalfNormal("sigma", 1)
185-
y = pm.Normal("y", mu, sigma, observed=Y)
186-
idata = pm.sample(random_seed=3415, chains=1)
187-
var_imp = (
188-
idata.sample_stats["variable_inclusion"]
189-
.stack(samples=("chain", "draw"))
190-
.mean("samples")
191-
)
192-
var_imp /= var_imp.sum()
193-
assert var_imp[0] > var_imp[1:].sum()
194-
npt.assert_almost_equal(var_imp.sum(), 1)
195-
196177
def test_return_inferencedata(self):
197178
with self.model:
198179
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())

0 commit comments

Comments
 (0)