
Commit ab52503

Bugfixes and address comments by @theorashid on PR pymc-devs#385
1 parent ea8a926 commit ab52503

4 files changed: +105 −54 lines

pymc_extras/inference/find_map.py (+31 −15)
@@ -30,13 +30,29 @@
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp

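The new defaulting logic above has three cases. As a rough standalone sketch (not the committed function; mode_info below is a stand-in for the real MINIMIZE_MODE_KWARGS table, which is not shown in this diff), the precedence works out like this:

mode_info = {"trust-ncg": {"uses_grad": True, "uses_hess": True, "uses_hessp": True}}

def resolve_defaults(method, use_grad=None, use_hess=None, use_hessp=None):
    info = mode_info[method]
    if use_hess and use_hessp:
        use_hess = False  # scipy.optimize.minimize never uses both; keep the cheaper hessp
    use_grad = use_grad if use_grad is not None else info["uses_grad"]
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = info["uses_hessp"]
        use_hess = info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False  # when a method could use either, prefer the Hessian-vector product
    return use_grad, use_hess, use_hessp

print(resolve_defaults("trust-ncg"))                 # (True, False, True): nothing given, hessp wins
print(resolve_defaults("trust-ncg", use_hess=True))  # (True, True, False): explicit choice is kept

The net effect: an explicit user choice wins, a single explicit flag implies the opposite value for its counterpart, and when nothing is specified the method's capabilities decide, with hessp preferred over hess.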
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
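A quick aside on why the switch to np.linalg.eigh matters: for a symmetric matrix, eigh guarantees real, sorted eigenvalues and orthonormal eigenvectors, while eig can return complex output with tiny imaginary parts from round-off. A minimal sketch of the same nearest-PSD recipe in plain NumPy (illustrative, not the library function):

import numpy as np

def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, then clip negative eigenvalues to zero, as get_nearest_psd does.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # symmetric solver: real outputs, ascending eigenvalues
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [0.0, -1.0]])            # neither symmetric nor PSD
B = nearest_psd(A)
assert np.all(np.linalg.eigvalsh(B) >= -1e-12)     # projection is PSD up to round-off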
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -152,7 +168,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +193,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +209,19 @@
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]

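For orientation, the three compiled callables returned here are the ones ultimately handed to scipy.optimize.minimize. A self-contained sketch of that wiring, with plain NumPy functions standing in for the PyTensor-compiled ones (the quadratic loss is illustrative only):

import numpy as np
from scipy.optimize import minimize

def f_loss_and_grad(x):
    # Fused value-and-gradient, analogous to the compiled f_loss_and_grad above.
    loss = np.sum((x - 3.0) ** 2)
    grad = 2.0 * (x - 3.0)
    return loss, grad

def f_hessp(x, p):
    # Hessian-vector product of the quadratic loss: H = 2 * I, so H @ p = 2 * p.
    return 2.0 * p

# jac=True tells scipy the objective returns (value, gradient) in a single call,
# which is why the loss and gradient are compiled together above.
res = minimize(f_loss_and_grad, x0=np.zeros(3), jac=True, hessp=f_hessp, method="trust-ncg")
print(res.x)  # ~[3., 3., 3.]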
@@ -240,7 +256,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -285,7 +301,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +317,7 @@ def scipy_optimize_funcs_from_loss(

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp

pymc_extras/inference/laplace.py (+13 −10)
@@ -231,7 +231,7 @@ def add_data_to_inferencedata(
     return idata


-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -344,8 +344,10 @@ def sample_laplace_posterior(

     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +386,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )

-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)

     else:
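The pm.compile_pymc to pm.compile changes throughout this commit follow PyMC's rename of its compilation helper (treat the exact PyMC version that exposes pm.compile as an assumption here). A minimal sketch of the call pattern on a toy graph, not the batched one above:

import pymc as pm
import pytensor.tensor as pt

x = pt.vector("x")
y = (x ** 2).sum()

# pm.compile plays the role previously filled by pm.compile_pymc: a thin wrapper around
# pytensor.function with PyMC's default rewrites and compile_kwargs handling.
f = pm.compile(inputs=[x], outputs=y)
print(f([1.0, 2.0, 3.0]))  # -> 14.0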
@@ -472,15 +472,17 @@ def fit_laplace(
         and 1).

         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.

     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
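To make the clarified chains/draws semantics concrete, a usage sketch (the toy model and numbers are illustrative, not from the commit):

import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu=mu, sigma=1, observed=np.random.normal(0.5, 1, size=100))

    idata = pmx.fit(method="laplace", optimize_method="BFGS", chains=2, draws=500)

# "chains" only labels an output dimension of the InferenceData;
# total draws from the approximation = chains * draws.
assert idata.posterior["mu"].shape == (2, 500)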
@@ -547,11 +549,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,

tests/test_find_map.py (+19 −14)
@@ -54,24 +54,28 @@ def compute_z(x):


 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("trust-ncg", True, True),
-        ("trust-exact", True, True),
-        ("trust-krylov", True, True),
-        ("trust-constr", True, True),
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},

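The doubled Newton-CG, trust-ncg, and trust-krylov rows are there because those scipy methods accept either a full Hessian (hess) or a Hessian-vector product (hessp), and both paths are now exercised. A scipy-only sketch of the two equivalent calls on a toy quadratic (illustrative assumptions, not the test model):

import numpy as np
from scipy.optimize import minimize

fun = lambda x: np.sum((x - 1.0) ** 2)
jac = lambda x: 2.0 * (x - 1.0)
hess = lambda x: 2.0 * np.eye(x.size)   # full Hessian, the use_hess=True path
hessp = lambda x, p: 2.0 * p            # Hessian-vector product, the use_hessp=True path

x0 = np.zeros(4)
res_hess = minimize(fun, x0, jac=jac, hess=hess, method="Newton-CG")
res_hessp = minimize(fun, x0, jac=jac, hessp=hessp, method="Newton-CG")
np.testing.assert_allclose(res_hess.x, res_hessp.x, atol=1e-6)  # both converge to ones(4)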
tests/test_laplace.py (+42 −15)
@@ -19,10 +19,10 @@

 import pymc_extras as pmx

-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )

@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
         vars = [mu, logsigma]

         idata = pmx.fit(
-            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+            method="laplace",
+            optimize_method="trust-ncg",
+            draws=draws,
+            random_seed=173300,
+            chains=1,
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )

     assert idata.posterior["mu"].shape == (1, draws)
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)


-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
     )
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         use_hessp=True,
         progressbar=False,
         compile_kwargs=dict(mode=mode),
-        gradient_backend="jax" if mode == "JAX" else "pytensor",
+        gradient_backend=gradient_backend,
     )

     for value in optimized_point.values():
         assert value.shape == (3,)

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]


-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )

     assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )

     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
