Skip to content

Commit c1c3a0b

Browse files
committed
fix number of chains
1 parent d1982dc commit c1c3a0b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc3/tests/test_bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_model():
3333
sigma = pm.HalfNormal("sigma", 1)
3434
mu = pm.BART("mu", X, Y, m=50)
3535
y = pm.Normal("y", mu, sigma, observed=Y)
36-
idata = pm.sample()
36+
idata = pm.sample(chains=4)
3737
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")
3838

3939
np.testing.assert_allclose(mean, Y, 0.5)
@@ -43,7 +43,7 @@ def test_model():
4343
mu_ = pm.BART("mu_", X, Y, m=50)
4444
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
4545
y = pm.Bernoulli("y", mu, observed=Y)
46-
idata = pm.sample()
46+
idata = pm.sample(chains=4)
4747
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")
4848

4949
np.testing.assert_allclose(mean, Y, atol=0.5)

0 commit comments

Comments
 (0)