1
1
import numpy as np
2
2
3
+ from numpy .random import RandomState
4
+
3
5
import pymc3 as pm
4
6
5
7
@@ -33,7 +35,7 @@ def test_bart_vi():
33
35
mu = pm .BART ("mu" , X , Y , m = 10 )
34
36
sigma = pm .HalfNormal ("sigma" , 1 )
35
37
y = pm .Normal ("y" , mu , sigma , observed = Y )
36
- idata = pm .sample (random_seed = 3415 , chains = 1 )
38
+ idata = pm .sample (random_seed = 3415 )
37
39
var_imp = (
38
40
idata .sample_stats ["variable_inclusion" ]
39
41
.stack (samples = ("chain" , "draw" ))
@@ -42,3 +44,23 @@ def test_bart_vi():
42
44
var_imp /= var_imp .sum ()
43
45
assert var_imp [0 ] > var_imp [1 :].sum ()
44
46
np .testing .assert_almost_equal (var_imp .sum (), 1 )
47
+
48
+
49
+ def test_bart_random ():
50
+ X = np .random .normal (0 , 1 , size = (2 , 50 )).T
51
+ Y = np .random .normal (0 , 1 , size = 50 )
52
+
53
+ with pm .Model () as model :
54
+ mu = pm .BART ("mu" , X , Y , m = 10 )
55
+ sigma = pm .HalfNormal ("sigma" , 1 )
56
+ y = pm .Normal ("y" , mu , sigma , observed = Y )
57
+ idata = pm .sample (random_seed = 3415 , chains = 1 )
58
+
59
+ rng = RandomState (12345 )
60
+ pred_all = mu .owner .op .rng_fn (rng , size = 2 )
61
+ rng = RandomState (12345 )
62
+ pred_first = mu .owner .op .rng_fn (rng , X_new = X [:10 ])
63
+
64
+ assert np .all (pred_first == pred_all [0 , :10 ])
65
+ assert pred_all .shape == (2 , 50 )
66
+ assert pred_first .shape == (10 ,)
0 commit comments