Skip to content

Commit efa8dc6

Browse files
committed
test random
1 parent 6cebc10 commit efa8dc6

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

pymc3/tests/test_bart.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from numpy.random import RandomState
4+
35
import pymc3 as pm
46

57

@@ -33,7 +35,7 @@ def test_bart_vi():
3335
mu = pm.BART("mu", X, Y, m=10)
3436
sigma = pm.HalfNormal("sigma", 1)
3537
y = pm.Normal("y", mu, sigma, observed=Y)
36-
idata = pm.sample(random_seed=3415, chains=1)
38+
idata = pm.sample(random_seed=3415)
3739
var_imp = (
3840
idata.sample_stats["variable_inclusion"]
3941
.stack(samples=("chain", "draw"))
@@ -42,3 +44,23 @@ def test_bart_vi():
4244
var_imp /= var_imp.sum()
4345
assert var_imp[0] > var_imp[1:].sum()
4446
np.testing.assert_almost_equal(var_imp.sum(), 1)
47+
48+
49+
def test_bart_random():
50+
X = np.random.normal(0, 1, size=(2, 50)).T
51+
Y = np.random.normal(0, 1, size=50)
52+
53+
with pm.Model() as model:
54+
mu = pm.BART("mu", X, Y, m=10)
55+
sigma = pm.HalfNormal("sigma", 1)
56+
y = pm.Normal("y", mu, sigma, observed=Y)
57+
idata = pm.sample(random_seed=3415, chains=1)
58+
59+
rng = RandomState(12345)
60+
pred_all = mu.owner.op.rng_fn(rng, size=2)
61+
rng = RandomState(12345)
62+
pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])
63+
64+
assert np.all(pred_first == pred_all[0, :10])
65+
assert pred_all.shape == (2, 50)
66+
assert pred_first.shape == (10,)

0 commit comments

Comments
 (0)