@@ -19,10 +19,10 @@

 import pymc_extras as pmx

-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )

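Every test touched below gains the same four-way backend matrix via @pytest.mark.parametrize. As a reading aid, here are the combinations and what each one exercises; the gloss is ours, and treating GradientBackend as the string-literal type behind the gradient_backend argument is an assumption based on the new import, not something the diff states:

# (mode, gradient_backend) pairs swept by each parametrized test in this diff.
BACKEND_MATRIX = [
    (None, "pytensor"),     # default PyTensor compilation, PyTensor-derived gradients
    ("NUMBA", "pytensor"),  # Numba-compiled functions, PyTensor-derived gradients
    ("JAX", "jax"),         # JAX-compiled functions, JAX-derived gradients
    ("JAX", "pytensor"),    # JAX-compiled functions, PyTensor-derived gradients
]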
@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
         vars = [mu, logsigma]

         idata = pmx.fit(
-            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+            method="laplace",
+            optimize_method="trust-ncg",
+            draws=draws,
+            random_seed=173300,
+            chains=1,
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )

     assert idata.posterior["mu"].shape == (1, draws)
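To try the new knobs outside the test suite, a minimal sketch: the toy model below is ours, but the pmx.fit keywords are exactly those shown in the hunk above.

import numpy as np
import pymc as pm
import pymc_extras as pmx

y = np.random.default_rng(0).normal(loc=3.0, scale=1.5, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # Any of the four parametrized combinations should work here, e.g. Numba
    # compilation with PyTensor-derived gradients:
    idata = pmx.fit(
        method="laplace",
        optimize_method="trust-ncg",
        draws=1000,
        random_seed=173300,
        chains=1,
        compile_kwargs={"mode": "NUMBA"},
        gradient_backend="pytensor",
    )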
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)


-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
             method="laplace",
             optimize_method="BFGS",
             progressbar=True,
-            gradient_backend="jax",
-            compile_kwargs={"mode": "JAX"},
+            gradient_backend=gradient_backend,
+            compile_kwargs={"mode": mode},
             optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
             random_seed=173300,
         )
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         use_hessp=True,
         progressbar=False,
         compile_kwargs=dict(mode=mode),
-        gradient_backend="jax" if mode == "JAX" else "pytensor",
+        gradient_backend=gradient_backend,
     )

     for value in optimized_point.values():
         assert value.shape == (3,)

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,
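The fit_mvn_at_MAP rename above is the one user-facing API change in this diff. Below is a sketch of the lower-level path it sits in, with the caveat that only the keywords visible in the hunk are confirmed; the find_MAP method argument and the sample_laplace_posterior keywords are assumptions from surrounding context:

import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.laplace import fit_mvn_at_MAP, sample_laplace_posterior

with pm.Model() as model:
    mu = pm.Normal("mu", mu=3, sigma=0.5)
    pm.Normal("obs", mu=mu, sigma=1.5,
              observed=np.random.default_rng(0).normal(3.0, 1.5, size=100))

# Step 1: optimize to the MAP point, with the same backend knobs the tests sweep.
optimized_point = find_MAP(
    method="trust-ncg",  # assumed; the visible hunk starts below this argument
    model=model,
    use_grad=True,
    use_hessp=True,
    progressbar=False,
    compile_kwargs=dict(mode=None),
    gradient_backend="pytensor",
)

# Step 2: fit a multivariate normal at the MAP point (renamed from fit_mvn_to_MAP).
mu_map, H_inv = fit_mvn_at_MAP(
    optimized_point=optimized_point,
    model=model,
    transform_samples=True,
)

# Step 3: draw from the resulting Gaussian approximation (keyword names assumed).
idata = sample_laplace_posterior(mu=mu_map, H_inv=H_inv, model=model, transform_samples=True)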
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]


-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )

     assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )

     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)