We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d1982dc commit c1c3a0bCopy full SHA for c1c3a0b
pymc3/tests/test_bart.py
@@ -33,7 +33,7 @@ def test_model():
33
sigma = pm.HalfNormal("sigma", 1)
34
mu = pm.BART("mu", X, Y, m=50)
35
y = pm.Normal("y", mu, sigma, observed=Y)
36
- idata = pm.sample()
+ idata = pm.sample(chains=4)
37
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")
38
39
np.testing.assert_allclose(mean, Y, 0.5)
@@ -43,7 +43,7 @@ def test_model():
43
mu_ = pm.BART("mu_", X, Y, m=50)
44
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
45
y = pm.Bernoulli("y", mu, observed=Y)
46
47
48
49
np.testing.assert_allclose(mean, Y, atol=0.5)
0 commit comments