Skip to content

Commit 6f206bd

Browse files
committed
add tests
1 parent 2dc8b69 commit 6f206bd

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

pymc/bart/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def predict(idata, rng, X_new=None, size=None):
5151

5252

5353
def plot_dependence(
54-
bart_trees,
54+
idata,
5555
X=None,
5656
Y=None,
5757
kind="pdp",
@@ -78,8 +78,8 @@ def plot_dependence(
7878
7979
Parameters
8080
----------
81-
bart_trees: DataArray of trees
82-
BART trees
81+
idata: InferenceData
82+
InferenceData containing a collection of BART_trees in sample_stats group
8383
X : array-like
8484
The covariate matrix.
8585
Y : array-like
@@ -144,6 +144,7 @@ def plot_dependence(
144144
)
145145

146146
rng = RandomState(seed=random_seed)
147+
bart_trees = idata.sample_stats.bart_trees
147148

148149
if isinstance(X, pd.DataFrame):
149150
X_names = list(X.columns)

pymc/tests/test_bart.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
from numpy.random import RandomState
45
from numpy.testing import assert_almost_equal
@@ -47,33 +48,47 @@ def test_bart_vi():
4748
assert_almost_equal(var_imp.sum(), 1)
4849

4950

50-
def test_bart_random():
51+
def test_missing_data():
5152
X = np.random.normal(0, 1, size=(2, 50)).T
5253
Y = np.random.normal(0, 1, size=50)
54+
X[10:20, 0] = np.nan
5355

5456
with pm.Model() as model:
5557
mu = pm.BART("mu", X, Y, m=10)
5658
sigma = pm.HalfNormal("sigma", 1)
5759
y = pm.Normal("y", mu, sigma, observed=Y)
58-
idata = pm.sample(random_seed=3415, chains=1)
59-
60-
rng = RandomState(12345)
61-
pred_all = pm.bart.utils.predict(idata, rng, size=2)
62-
rng = RandomState(12345)
63-
pred_first = pm.bart.utils.predict(idata, rng, X_new=X[:10])
64-
65-
assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
66-
assert pred_all.shape == (2, 50)
67-
assert pred_first.shape == (10,)
60+
idata = pm.sample(random_seed=3415)
6861

6962

70-
def test_missing_data():
63+
def test_utils():
7164
X = np.random.normal(0, 1, size=(2, 50)).T
7265
Y = np.random.normal(0, 1, size=50)
73-
X[10:20, 0] = np.nan
7466

7567
with pm.Model() as model:
7668
mu = pm.BART("mu", X, Y, m=10)
7769
sigma = pm.HalfNormal("sigma", 1)
7870
y = pm.Normal("y", mu, sigma, observed=Y)
7971
idata = pm.sample(random_seed=3415)
72+
73+
def test_predict():
74+
rng = RandomState(12345)
75+
pred_all = pm.bart.utils.predict(idata, rng, size=2)
76+
rng = RandomState(12345)
77+
pred_first = pm.bart.utils.predict(idata, rng, X_new=X[:10])
78+
79+
assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
80+
assert pred_all.shape == (2, 50)
81+
assert pred_first.shape == (10,)
82+
83+
@pytest.mark.parametrize(
84+
"kwargs",
85+
[
86+
{},
87+
{"kind": "pdp", "xs_interval": "quantiles", "xs_values": [0.25, 0.5, 0.75]},
88+
{"kind": "ice", "instances": 5},
89+
{"var_idx": 0, "rug": False, "smooth": False, "color": "k"},
90+
{"grid": (1, 2), "sharey": "False", "alpha": 1},
91+
],
92+
)
93+
def test_pdp():
94+
pm.bart.utils.plot_dependence(idata, X=None, Y=None, **kwargs)

0 commit comments

Comments
 (0)