
Commit 4ddd7f8

Bugfixes and address comments by @theorashid on PR #385
Check for `jax` installation before any computation if `gradient_backend = 'jax'`
1 parent ea8a926 commit 4ddd7f8
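
The check described above amounts to probing importlib for the optional dependency before any graph building or compilation starts. A minimal sketch of that guard, assuming a free-standing helper (the name ensure_jax_available is illustrative, not part of this commit):

from importlib.util import find_spec


def ensure_jax_available(gradient_backend: str) -> None:
    # Hypothetical helper mirroring the guard added in this commit:
    # fail fast if the JAX backend is requested but jax is not installed.
    if gradient_backend == "jax" and find_spec("jax") is None:
        raise ImportError("JAX must be installed to use JAX gradients")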

4 files changed (+114, -55 lines)


pymc_extras/inference/find_map.py (+36, -16)
@@ -1,9 +1,9 @@
 import logging

 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args

-import jax
 import numpy as np
 import pymc as pm
 import pytensor

@@ -30,13 +30,29 @@
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
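
The reworked defaults read as a small decision rule: an explicit value for use_hessp implies the opposite default for use_hess (and vice versa), and when neither is given the method's capabilities from MINIMIZE_MODE_KWARGS decide, with hessp preferred when a method accepts either. A rough standalone restatement of that logic, where the method_info dict is a made-up stand-in for MINIMIZE_MODE_KWARGS[method]:

def resolve_hessian_defaults(use_hess, use_hessp, method_info):
    # Illustrative restatement of the resolution above, outside the library.
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hess is None and use_hessp is None:
        use_hess = method_info["uses_hess"]
        use_hessp = method_info["uses_hessp"]
        if use_hess and use_hessp:
            # Prefer hessp when a method can use either form
            use_hess = False
    return use_hess, use_hessp


# A trust-region style method that accepts either hess or hessp:
print(resolve_hessian_defaults(None, None, {"uses_hess": True, "uses_hessp": True}))
# -> (False, True)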

@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
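
Using np.linalg.eigh instead of np.linalg.eig is the right call here: C is symmetric by construction, so eigh returns real eigenvalues and orthonormal eigenvectors (and is cheaper). A self-contained version of the function as it reads after this change, with an illustrative check:

import numpy as np


def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, clip negative eigenvalues to zero, and reassemble.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # eigh: C is symmetric by construction
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T


A = np.array([[1.0, 2.0], [2.0, 1.0]])  # indefinite: eigenvalues are 3 and -1
A_psd = get_nearest_psd(A)
assert np.all(np.linalg.eigvalsh(A_psd) >= 0)
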
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """

@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
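
Moving `import jax` inside `_compile_grad_and_hess_to_jax` (together with removing the top-level `import jax` above) means importing this module no longer requires JAX at all; the dependency is only touched when JAX gradients are actually requested. For reference, a Hessian-vector product over a scalar loss can be built in JAX without materializing the full Hessian. This is a generic sketch of that idiom, not the exact code in this module:

import jax
import jax.numpy as jnp


def loss(x):
    # Stand-in scalar loss; the real code wraps the compiled PyMC loss function.
    return jnp.sum(jnp.sin(x) ** 2)


def hessp(x, p):
    # Hessian-vector product via forward-over-reverse differentiation.
    return jax.jvp(jax.grad(loss), (x,), (p,))[1]


x = jnp.ones(3)
p = jnp.array([1.0, 0.0, 0.0])
print(hessp(x, p))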

@@ -152,7 +170,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,

@@ -177,7 +195,7 @@
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +211,19 @@
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]
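
These hunks simply track the rename of ``pm.compile_pymc`` to ``pm.compile`` in current PyMC; the compiled graphs are unchanged. As a rough illustration of the pattern being compiled, a single function returning a scalar loss together with its flattened gradient can be built from any PyTensor graph (the variables below are illustrative, not taken from the module):

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = pt.sum(x**2)  # stand-in scalar loss

grads = pytensor.gradient.grad(loss, [x])
grad = pt.concatenate([g.ravel() for g in grads])

# One compiled call returns both the loss and its gradient, as in the diff above.
f_loss_and_grad = pm.compile([x], [loss, grad])
print(f_loss_and_grad(np.array([1.0, 2.0, 3.0])))  # [array(14.), array([2., 4., 6.])]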

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------

@@ -265,6 +283,8 @@
     )

     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:

@@ -285,7 +305,7 @@
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,

@@ -301,7 +321,7 @@

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp

pymc_extras/inference/laplace.py (+17, -10)
@@ -16,6 +16,7 @@
 import logging

 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata


-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",

@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)

@@ -344,8 +348,10 @@ def sample_laplace_posterior(

     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
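
The newly documented arguments make the shape of the approximation explicit: the Laplace posterior is a multivariate normal centered at the MAP estimate mu with covariance H_inv, so drawing from it conceptually reduces to the sketch below. This is only a conceptual illustration with made-up numbers; the library additionally maps the draws back to the constrained parameter space:

import numpy as np

rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0])         # MAP estimate (flattened parameter vector)
H_inv = np.array([[0.20, 0.05],
                  [0.05, 0.10]])   # inverse Hessian at the MAP

chains, draws = 2, 500
# Shape (chains, draws, n_params), matching the ArviZ-style output layout.
posterior_draws = rng.multivariate_normal(mean=mu, cov=H_inv, size=(chains, draws))
print(posterior_draws.shape)  # (2, 500, 2)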
@@ -384,9 +390,7 @@
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )

-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)

     else:

@@ -472,15 +476,17 @@ def fit_laplace(
         and 1).

         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.

     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
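
The clarified chains/draws documentation is easiest to see in a call: with the defaults chains=2 and draws=500, the returned posterior carries 2 * 500 = 1000 total samples even though nothing runs in parallel. A hedged usage sketch, assuming the function is exposed at the package level as in pymc-extras (the model itself is illustrative):

import numpy as np
import pymc as pm
import pymc_extras as pmx

observed = np.random.default_rng(0).normal(1.0, 2.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu, sigma, observed=observed)

    idata = pmx.fit_laplace(
        chains=2,   # chain *dimension* of the output only; no parallel sampling happens
        draws=500,  # total samples = chains * draws = 1000
        gradient_backend="pytensor",
    )

print(idata.posterior.sizes)  # expect chain: 2, draw: 500
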
@@ -547,11 +553,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,

tests/test_find_map.py (+19, -14)
@@ -54,24 +54,28 @@ def compute_z(x):


 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("trust-ncg", True, True),
-        ("trust-exact", True, True),
-        ("trust-krylov", True, True),
-        ("trust-constr", True, True),
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point

@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
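
The expanded parametrization exercises use_hessp as a first-class option alongside use_hess. In user code the same switch looks roughly like the sketch below, assuming the pymc-extras find_MAP entry point that this test targets (the model is illustrative):

import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=np.random.default_rng(0).normal(size=50))

    optimized_point = pmx.find_MAP(
        method="trust-ncg",
        use_grad=True,
        use_hess=False,
        use_hessp=True,  # Hessian-vector products instead of a full Hessian
        gradient_backend="pytensor",
        progressbar=False,
    )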
