
Allow method="basinhopping" in find_MAP and fit_laplace #467

Merged 1 commit on May 2, 2025
69 changes: 53 additions & 16 deletions pymc_extras/inference/find_map.py
@@ -9,7 +9,7 @@
import pytensor
import pytensor.tensor as pt

from better_optimize import minimize
from better_optimize import basinhopping, minimize
from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(


def find_MAP(
method: minimize_method,
method: minimize_method | Literal["basinhopping"],
*,
model: pm.Model | None = None,
use_grad: bool | None = None,
@@ -352,14 +352,17 @@ def find_MAP(
**optimizer_kwargs,
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
"""
Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.

Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.
method : str
The optimization method to use. See scipy.optimize.minimize documentation for details.
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.

See scipy.optimize.minimize documentation for details.
use_grad : bool | None, optional
Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
the ``method``.
@@ -387,7 +390,9 @@
compile_kwargs: dict, optional
Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
**optimizer_kwargs
Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.

Returns
-------
@@ -413,6 +418,18 @@
initial_params = DictToArrayBijection.map(
{var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
)

do_basinhopping = method == "basinhopping"
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})

if do_basinhopping:
# For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
# another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
# if one isn't provided.

method = minimizer_kwargs.pop("method", "L-BFGS-B")
minimizer_kwargs["method"] = method

use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
method, use_grad, use_hess, use_hessp
)
@@ -431,17 +448,37 @@
args = optimizer_kwargs.pop("args", None)

# better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
# if so. That is why it is not set here, regardless of user settings.
optimizer_result = minimize(
f=f_logp,
x0=cast(np.ndarray[float], initial_params.data),
args=args,
hess=f_hess,
hessp=f_hessp,
progressbar=progressbar,
method=method,
**optimizer_kwargs,
)
# if so. That is why the jac argument is not passed here in either branch.

if do_basinhopping:
if "args" not in minimizer_kwargs:
minimizer_kwargs["args"] = args
if "hess" not in minimizer_kwargs:
minimizer_kwargs["hess"] = f_hess
if "hessp" not in minimizer_kwargs:
minimizer_kwargs["hessp"] = f_hessp
if "method" not in minimizer_kwargs:
minimizer_kwargs["method"] = method

optimizer_result = basinhopping(
func=f_logp,
x0=cast(np.ndarray[float], initial_params.data),
progressbar=progressbar,
minimizer_kwargs=minimizer_kwargs,
**optimizer_kwargs,
)

else:
optimizer_result = minimize(
f=f_logp,
x0=cast(np.ndarray[float], initial_params.data),
args=args,
hess=f_hess,
hessp=f_hessp,
progressbar=progressbar,
method=method,
**optimizer_kwargs,
)

raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
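Below is a minimal usage sketch (not part of this PR) of the new option. The toy model mirrors the test added further down, and the import path follows the module edited above; the choice of L-BFGS-B as the inner local optimizer is an assumption passed through ``minimizer_kwargs``.

import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP

rng = np.random.default_rng(0)

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

    # method="basinhopping" dispatches to scipy.optimize.basinhopping via
    # better_optimize; the inner local optimizer defaults to L-BFGS-B when
    # minimizer_kwargs does not specify one.
    map_point = find_MAP(
        method="basinhopping",
        progressbar=False,
        minimizer_kwargs={"method": "L-BFGS-B"},
    )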
17 changes: 10 additions & 7 deletions pymc_extras/inference/laplace.py
@@ -416,7 +416,7 @@ def sample_laplace_posterior(


def fit_laplace(
optimize_method: minimize_method = "BFGS",
optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
*,
model: pm.Model | None = None,
use_grad: bool | None = None,
@@ -449,8 +449,11 @@
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.
optimize_method : str
The optimization method to use. See scipy.optimize.minimize documentation for details.
optimize_method : str
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.

See scipy.optimize.minimize documentation for details.
use_grad : bool | None, optional
Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
the ``method``.
@@ -500,10 +503,10 @@
diag_jitter: float | None
A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
If None, no jitter is added. Default is 1e-8.
optimizer_kwargs: dict, optional
Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
to use a nested dictionary.
optimizer_kwargs
Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
compile_kwargs: dict, optional
Additional keyword arguments to pass to pytensor.function.

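The same option applies to fit_laplace through its ``optimize_method`` argument. A brief, hypothetical sketch under the same toy model (per the updated docstring, inner-optimizer settings would be forwarded through the optimizer keyword arguments):

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace import fit_laplace

rng = np.random.default_rng(0)

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

    # Basin-hopping searches for the MAP point globally, then the Laplace
    # approximation of the posterior is built around it.
    idata = fit_laplace(optimize_method="basinhopping")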
32 changes: 32 additions & 0 deletions tests/test_find_map.py
@@ -124,3 +124,35 @@ def test_JAX_map_shared_variables():

assert np.isclose(mu_hat, 3, atol=0.5)
assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)


@pytest.mark.parametrize(
"method, use_grad, use_hess, use_hessp",
[
("nelder-mead", False, False, False),
("L-BFGS-B", True, False, False),
("trust-exact", True, True, False),
("trust-ncg", True, False, True),
],
)
def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng):
with pm.Model() as m:
mu = pm.Normal("mu")
sigma = pm.Exponential("sigma", 1)
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

optimized_point = find_MAP(
method="basinhopping",
use_grad=use_grad,
use_hess=use_hess,
use_hessp=use_hessp,
progressbar=False,
gradient_backend="pytensor",
compile_kwargs={"mode": "JAX"},
minimizer_kwargs=dict(method=method),
)

mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]

assert np.isclose(mu_hat, 3, atol=0.5)
assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)