|
1 | 1 | import numpy as np
|
| 2 | +import pytest |
2 | 3 |
|
3 | 4 | from numpy.random import RandomState
|
4 | 5 | from numpy.testing import assert_almost_equal
|
@@ -47,33 +48,47 @@ def test_bart_vi():
|
47 | 48 | assert_almost_equal(var_imp.sum(), 1)
|
48 | 49 |
|
49 | 50 |
|
50 |
| -def test_bart_random(): |
| 51 | +def test_missing_data(): |
51 | 52 | X = np.random.normal(0, 1, size=(2, 50)).T
|
52 | 53 | Y = np.random.normal(0, 1, size=50)
|
| 54 | + X[10:20, 0] = np.nan |
53 | 55 |
|
54 | 56 | with pm.Model() as model:
|
55 | 57 | mu = pm.BART("mu", X, Y, m=10)
|
56 | 58 | sigma = pm.HalfNormal("sigma", 1)
|
57 | 59 | y = pm.Normal("y", mu, sigma, observed=Y)
|
58 |
| - idata = pm.sample(random_seed=3415, chains=1) |
59 |
| - |
60 |
| - rng = RandomState(12345) |
61 |
| - pred_all = pm.bart.utils.predict(idata, rng, size=2) |
62 |
| - rng = RandomState(12345) |
63 |
| - pred_first = pm.bart.utils.predict(idata, rng, X_new=X[:10]) |
64 |
| - |
65 |
| - assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) |
66 |
| - assert pred_all.shape == (2, 50) |
67 |
| - assert pred_first.shape == (10,) |
| 60 | + idata = pm.sample(random_seed=3415) |
68 | 61 |
|
69 | 62 |
|
70 |
| -def test_missing_data(): |
| 63 | +def test_utils(): |
71 | 64 | X = np.random.normal(0, 1, size=(2, 50)).T
|
72 | 65 | Y = np.random.normal(0, 1, size=50)
|
73 |
| - X[10:20, 0] = np.nan |
74 | 66 |
|
75 | 67 | with pm.Model() as model:
|
76 | 68 | mu = pm.BART("mu", X, Y, m=10)
|
77 | 69 | sigma = pm.HalfNormal("sigma", 1)
|
78 | 70 | y = pm.Normal("y", mu, sigma, observed=Y)
|
79 | 71 | idata = pm.sample(random_seed=3415)
|
| 72 | + |
| 73 | + def test_predict(): |
| 74 | + rng = RandomState(12345) |
| 75 | + pred_all = pm.bart.utils.predict(idata, rng, size=2) |
| 76 | + rng = RandomState(12345) |
| 77 | + pred_first = pm.bart.utils.predict(idata, rng, X_new=X[:10]) |
| 78 | + |
| 79 | + assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) |
| 80 | + assert pred_all.shape == (2, 50) |
| 81 | + assert pred_first.shape == (10,) |
| 82 | + |
| 83 | + @pytest.mark.parametrize( |
| 84 | + "kwargs", |
| 85 | + [ |
| 86 | + {}, |
| 87 | + {"kind": "pdp", "xs_interval": "quantiles", "xs_values": [0.25, 0.5, 0.75]}, |
| 88 | + {"kind": "ice", "instances": 5}, |
| 89 | + {"var_idx": 0, "rug": False, "smooth": False, "color": "k"}, |
| 90 | + {"grid": (1, 2), "sharey": "False", "alpha": 1}, |
| 91 | + ], |
| 92 | + ) |
| 93 | + def test_pdp(): |
| 94 | + pm.bart.utils.plot_dependence(idata, X=None, Y=None, **kwargs) |
0 commit comments