Commit f2656f5

Allow method="basinhopping" in find_MAP and fit_laplace (#467)
1 parent 413a4cb commit f2656f5

3 files changed: +95 −23 lines changed


pymc_extras/inference/find_map.py

Lines changed: 53 additions & 16 deletions

@@ -9,7 +9,7 @@
 import pytensor
 import pytensor.tensor as pt

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(


 def find_MAP(
-    method: minimize_method,
+    method: minimize_method | Literal["basinhopping"],
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -352,14 +352,17 @@ def find_MAP(
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
     """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.

     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -387,7 +390,9 @@ def find_MAP(
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.

     Returns
     -------
@@ -413,6 +418,18 @@ def find_MAP(
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
         method, use_grad, use_hess, use_hessp
     )
@@ -431,17 +448,37 @@ def find_MAP(
     args = optimizer_kwargs.pop("args", None)

     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why it is not set here, regardless of user settings.
-    optimizer_result = minimize(
-        f=f_logp,
-        x0=cast(np.ndarray[float], initial_params.data),
-        args=args,
-        hess=f_hess,
-        hessp=f_hessp,
-        progressbar=progressbar,
-        method=method,
-        **optimizer_kwargs,
-    )
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hess" not in minimizer_kwargs:
+            minimizer_kwargs["hess"] = f_hess
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hess=f_hess,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )

     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
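
For orientation, a minimal usage sketch of the new code path (a toy model, not part of this commit; the import path follows the file shown above). The diff only shows that leftover **optimizer_kwargs are forwarded to the basinhopping call, so whether extra keywords such as niter are accepted by better_optimize.basinhopping is an assumption here.

# Sketch only: hypothetical toy model. With method="basinhopping", find_MAP pops
# "minimizer_kwargs" for the inner optimizer (defaulting its method to L-BFGS-B)
# and forwards any remaining keyword arguments to the basinhopping call.
import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.default_rng(0).normal(3, 1.5, 100))

    map_point = find_MAP(
        method="basinhopping",
        minimizer_kwargs={"method": "L-BFGS-B"},  # inner scipy.optimize.minimize method
        niter=25,  # assumed forwarded to scipy.optimize.basinhopping via **optimizer_kwargs
        progressbar=False,
    )

print(map_point["mu"])  # MAP estimate of mu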

pymc_extras/inference/laplace.py

Lines changed: 10 additions & 7 deletions

@@ -416,7 +416,7 @@ def sample_laplace_posterior(


 def fit_laplace(
-    optimize_method: minimize_method = "BFGS",
+    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -449,8 +449,11 @@ def fit_laplace(
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
-    optimize_method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+    method : str
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -500,10 +503,10 @@ def fit_laplace(
     diag_jitter: float | None
         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
         If None, no jitter is added. Default is 1e-8.
-    optimizer_kwargs: dict, optional
-        Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
-        details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
-        to use a nested dictionary.
+    optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.
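
A matching sketch for fit_laplace, assuming optimizer_kwargs is accepted as a plain dict (as its docstring entry suggests) and threaded through to the find_MAP machinery above:

# Sketch only: hypothetical toy model; how optimizer_kwargs is threaded through is inferred
# from the docstring above, not shown in this diff. The nested "minimizer_kwargs" configure
# the inner minimizer; remaining keys would go to scipy.optimize.basinhopping.
import numpy as np
import pymc as pm

from pymc_extras.inference.laplace import fit_laplace

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.default_rng(0).normal(3, 1.5, 100))

    idata = fit_laplace(
        optimize_method="basinhopping",
        optimizer_kwargs={"minimizer_kwargs": {"method": "L-BFGS-B"}, "niter": 25},
    )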

tests/test_find_map.py

Lines changed: 32 additions & 0 deletions

@@ -124,3 +124,35 @@ def test_JAX_map_shared_variables():

     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "method, use_grad, use_hess, use_hessp",
+    [
+        ("nelder-mead", False, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("trust-exact", True, True, False),
+        ("trust-ncg", True, False, True),
+    ],
+)
+def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng):
+    with pm.Model() as m:
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
+
+        optimized_point = find_MAP(
+            method="basinhopping",
+            use_grad=use_grad,
+            use_hess=use_hess,
+            use_hessp=use_hessp,
+            progressbar=False,
+            gradient_backend="pytensor",
+            compile_kwargs={"mode": "JAX"},
+            minimizer_kwargs=dict(method=method),
+        )
+
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
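
The minimizer_kwargs=dict(method=...) pattern exercised in this test mirrors scipy's own basinhopping API, which the new branch in find_map.py wraps. A self-contained, scipy-only sketch of that pattern on a deliberately multimodal function:

# Plain scipy.optimize.basinhopping: an outer loop of random hops, each followed by a
# local minimize() run configured through minimizer_kwargs (the same split find_MAP exposes).
import numpy as np
from scipy.optimize import basinhopping


def rastrigin(x):
    # Many local minima; a single local minimize() started at x0 would likely get stuck.
    return 10 * x.size + np.sum(x**2 - 10 * np.cos(2 * np.pi * x))


result = basinhopping(
    rastrigin,
    x0=np.array([2.5, -1.5]),
    niter=200,  # number of basin-hopping iterations
    stepsize=0.5,  # scale of the random displacement between hops
    minimizer_kwargs={"method": "L-BFGS-B"},  # the inner local optimizer
)
print(result.x)  # expected to land near the global minimum at the origin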
