24
24
import pandas as pd
25
25
import pymc as pm
26
26
import xarray as xr
27
+ from pymc .backends import NDArray
28
+ from pymc .backends .base import MultiTrace
27
29
from pymc .util import RandomState
28
30
29
31
# If scikit-learn is available, use its data validator
@@ -425,6 +427,7 @@ def fit(
425
427
self ,
426
428
X : pd .DataFrame ,
427
429
y : Optional [pd .Series ] = None ,
430
+ fit_method = "mcmc" ,
428
431
progressbar : bool = True ,
429
432
predictor_names : List [str ] = None ,
430
433
random_seed : RandomState = None ,
@@ -441,6 +444,8 @@ def fit(
441
444
The training input samples.
442
445
y : array-like if sklearn is available, otherwise array, shape (n_obs,)
443
446
The target values (real numbers).
447
+ fit_method : str
448
+ Which method to use to infer model parameters. One of ["mcmc", "MAP"].
444
449
progressbar : bool
445
450
Specifies whether the fit progressbar should be displayed
446
451
predictor_names: List[str] = None,
@@ -449,19 +454,14 @@ def fit(
449
454
random_seed : RandomState
450
455
Provides sampler with initial random seed for obtaining reproducible samples
451
456
**kwargs : Any
452
- Custom sampler settings can be provided in form of keyword arguments.
453
-
454
- Returns
455
- -------
456
- self : az.InferenceData
457
- returns inference data of the fitted model.
458
- Examples
459
- --------
460
- >>> model = MyModel()
461
- >>> idata = model.fit(data)
462
- Auto-assigning NUTS sampler...
463
- Initializing NUTS using jitter+adapt_diag...
457
+ Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for
458
+ method-specific parameters.
464
459
"""
460
+ available_methods = ["mcmc" , "MAP" ]
461
+ if fit_method not in available_methods :
462
+ raise ValueError (
463
+ f"Inference method { fit_method } not found. Choose one of { available_methods } ."
464
+ )
465
465
if predictor_names is None :
466
466
predictor_names = []
467
467
if y is None :
@@ -474,14 +474,74 @@ def fit(
474
474
sampler_config ["progressbar" ] = progressbar
475
475
sampler_config ["random_seed" ] = random_seed
476
476
sampler_config .update (** kwargs )
477
- self .idata = self .sample_model (** sampler_config )
477
+
478
+ if fit_method == "mcmc" :
479
+ self .idata = self .sample_model (** sampler_config )
480
+ elif fit_method == "MAP" :
481
+ self .idata = self ._fit_MAP (** sampler_config )
478
482
479
483
X_df = pd .DataFrame (X , columns = X .columns )
480
484
combined_data = pd .concat ([X_df , y ], axis = 1 )
481
485
assert all (combined_data .columns ), "All columns must have non-empty names"
482
486
self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
483
487
return self .idata # type: ignore
484
488
489
def _fit_MAP(
    self,
    **kwargs,
):
    """Locate the model's maximum a posteriori point and wrap it as inference data.

    The options stored in ``self.sampler_config`` are merged with ``kwargs``
    (``kwargs`` win on conflicts), reduced to the argument names listed below,
    and forwarded to ``pm.find_MAP``.

    Parameters
    ----------
    **kwargs : Any
        Overrides for ``self.sampler_config``. Only the following names are
        forwarded; every other key is silently dropped: ``start``, ``vars``,
        ``method``, ``progressbar``, ``maxeval``, ``seed``, ``args``,
        ``bounds``, ``constraints``, ``tol``, ``callback``, ``options``.
        A ``random_seed`` entry is copied to ``seed`` before filtering.

    Returns
    -------
    The object returned by ``pm.to_inference_data`` for a one-draw,
    single-chain trace holding the MAP point, after ``self.set_idata_attrs``
    has been applied to it.
    """
    pymc_model = self.model

    merged_options = {**self.sampler_config, **kwargs}
    if "random_seed" in merged_options:
        # find_MAP names its seed argument differently from the sample_* functions.
        merged_options["seed"] = merged_options["random_seed"]

    # Unknown extra arguments cause problems for SciPy minimize, so only a
    # whitelist is let through.
    # find_MAP arguments; "return_raw" and "include_transformed" are left out
    # on purpose because they probably cause problems if set spuriously.
    find_map_names = {"start", "vars", "method", "progressbar", "maxeval", "seed"}
    # scipy.optimize.minimize arguments; "fun", "x0" and "jac" are used by
    # find_MAP itself, while "hess" and "hessp" probably cause problems if set
    # spuriously.
    minimize_names = {
        "args",
        "method",
        "bounds",
        "constraints",
        "tol",
        "callback",
        "options",
    }
    accepted_names = find_map_names | minimize_names
    optimizer_options = {
        name: value
        for name, value in merged_options.items()
        if name in accepted_names
    }

    map_point = pm.find_MAP(model=pymc_model, **optimizer_options)

    # Keep only the entries that correspond to the model's value variables.
    value_names = {var.name for var in pymc_model.value_vars}
    map_point = {
        name: value for name, value in map_point.items() if name in value_names
    }

    # Store the MAP point as a one-draw trace so it can be converted to
    # inference data.
    single_draw = NDArray(model=pymc_model)
    single_draw.setup(draws=1, chain=0)
    single_draw.record(map_point)
    single_draw.close()

    idata = pm.to_inference_data(MultiTrace([single_draw]), model=pymc_model)
    self.set_idata_attrs(idata)

    return idata
544
+
485
545
def predict (
486
546
self ,
487
547
X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
0 commit comments