@@ -9,7 +9,7 @@
 import pytensor
 import pytensor.tensor as pt

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
@@ -335,7 +335,7 @@ def scipy_optimize_funcs_from_loss(


 def find_MAP(
-    method: minimize_method,
+    method: minimize_method | Literal["basinhopping"],
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -352,14 +352,17 @@ def find_MAP(
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
     """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.

     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     method : str
-        The optimization method to use. See scipy.optimize.minimize documentation for details.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -387,7 +390,9 @@ def find_MAP(
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.

     Returns
     -------
@@ -413,6 +418,18 @@ def find_MAP(
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
         method, use_grad, use_hess, use_hessp
     )
@@ -431,17 +448,37 @@ def find_MAP(
     args = optimizer_kwargs.pop("args", None)

     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why it is not set here, regardless of user settings.
-    optimizer_result = minimize(
-        f=f_logp,
-        x0=cast(np.ndarray[float], initial_params.data),
-        args=args,
-        hess=f_hess,
-        hessp=f_hessp,
-        progressbar=progressbar,
-        method=method,
-        **optimizer_kwargs,
-    )
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hess" not in minimizer_kwargs:
+            minimizer_kwargs["hess"] = f_hess
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hess=f_hess,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )

     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
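
Example usage of the new option (a minimal sketch, not part of this commit; the import path and the toy model are illustrative assumptions):

import numpy as np
import pymc as pm

from pymc_extras.inference import find_MAP  # assumed import path for this package

rng = np.random.default_rng(0)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=rng.normal(1.0, 0.5, size=100))

    # Local optimization, as before: extra kwargs go to scipy.optimize.minimize
    map_local = find_MAP(method="L-BFGS-B")

    # Global optimization: extra kwargs (e.g. niter) go to scipy.optimize.basinhopping,
    # and the inner local optimizer is set via minimizer_kwargs (defaults to L-BFGS-B)
    map_global = find_MAP(
        method="basinhopping",
        minimizer_kwargs={"method": "Nelder-Mead"},
        niter=50,
    )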