diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py
index f6eacaa61..a4d664789 100644
--- a/pymc_extras/inference/find_map.py
+++ b/pymc_extras/inference/find_map.py
@@ -9,7 +9,7 @@
 import pytensor
 import pytensor.tensor as pt
 
-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(
 
 
 def find_MAP(
-    method: minimize_method,
+    method: minimize_method | Literal["basinhopping"],
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -352,14 +352,17 @@
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
     """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
 
     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -387,7 +390,9 @@
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
 
     Returns
     -------
@@ -413,6 +418,18 @@
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
         method, use_grad, use_hess, use_hessp
     )
@@ -431,17 +448,37 @@
     args = optimizer_kwargs.pop("args", None)
 
     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why it is not set here, regardless of user settings.
-    optimizer_result = minimize(
-        f=f_logp,
-        x0=cast(np.ndarray[float], initial_params.data),
-        args=args,
-        hess=f_hess,
-        hessp=f_hessp,
-        progressbar=progressbar,
-        method=method,
-        **optimizer_kwargs,
-    )
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hess" not in minimizer_kwargs:
+            minimizer_kwargs["hess"] = f_hess
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hess=f_hess,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )
 
     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index bc35d926e..d86f1fee5 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -416,7 +416,7 @@ def sample_laplace_posterior(
 
 
 def fit_laplace(
-    optimize_method: minimize_method = "BFGS",
+    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -449,8 +449,11 @@
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     optimize_method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -500,10 +503,10 @@
     diag_jitter: float | None
         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
         If None, no jitter is added. Default is 1e-8.
-    optimizer_kwargs: dict, optional
-        Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
-        details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
-        to use a nested dictionary.
+    optimizer_kwargs: dict, optional
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.
 
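For review context, a minimal sketch of the user-facing API this diff enables. The import path follows the file path above; the toy model, seed, and niter value are illustrative and not part of the change:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.find_map import find_MAP

    rng = np.random.default_rng(42)
    with pm.Model():
        mu = pm.Normal("mu")
        sigma = pm.Exponential("sigma", 1)
        pm.Normal("y", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

        # method="basinhopping" dispatches to scipy.optimize.basinhopping; the inner
        # local optimizer comes from minimizer_kwargs and defaults to L-BFGS-B.
        map_point = find_MAP(
            method="basinhopping",
            progressbar=False,
            minimizer_kwargs={"method": "L-BFGS-B"},
            niter=25,  # forwarded to scipy.optimize.basinhopping via **optimizer_kwargs
        )

    print(map_point["mu"], map_point["sigma_log__"])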
diff --git a/tests/test_find_map.py b/tests/test_find_map.py
index adb081eea..f5aa549c7 100644
--- a/tests/test_find_map.py
+++ b/tests/test_find_map.py
@@ -124,3 +124,35 @@ def test_JAX_map_shared_variables():
 
     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "method, use_grad, use_hess, use_hessp",
+    [
+        ("nelder-mead", False, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("trust-exact", True, True, False),
+        ("trust-ncg", True, False, True),
+    ],
+)
+def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng):
+    with pm.Model() as m:
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
+
+        optimized_point = find_MAP(
+            method="basinhopping",
+            use_grad=use_grad,
+            use_hess=use_hess,
+            use_hessp=use_hessp,
+            progressbar=False,
+            gradient_backend="pytensor",
+            compile_kwargs={"mode": "JAX"},
+            minimizer_kwargs=dict(method=method),
+        )
+
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
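Stripped of better_optimize's progress-bar and fused loss-and-gradient handling, the new branch amounts to a direct scipy.optimize.basinhopping call. A self-contained sketch, with a quadratic standing in for the compiled negative log-posterior f_logp:

    import numpy as np
    from scipy.optimize import basinhopping


    def neg_logp(x):
        # Stand-in for the compiled negative log-posterior (f_logp); minimum at x = 3.
        return float(np.sum((x - 3.0) ** 2))


    # The inner local optimizer and its options travel through minimizer_kwargs,
    # mirroring how find_MAP defaults the inner method to L-BFGS-B when the user
    # does not supply one.
    result = basinhopping(
        func=neg_logp,
        x0=np.zeros(2),
        minimizer_kwargs={"method": "L-BFGS-B"},
        niter=25,
    )
    assert np.allclose(result.x, 3.0, atol=1e-3)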