24
24
import pandas as pd
25
25
import pymc as pm
26
26
import xarray as xr
27
- from pymc .backends import NDArray
28
- from pymc .backends .base import MultiTrace
29
27
from pymc .util import RandomState
30
28
31
29
# If scikit-learn is available, use its data validator
@@ -427,7 +425,6 @@ def fit(
427
425
self ,
428
426
X : pd .DataFrame ,
429
427
y : Optional [pd .Series ] = None ,
430
- fit_method = "mcmc" ,
431
428
progressbar : bool = True ,
432
429
predictor_names : List [str ] = None ,
433
430
random_seed : RandomState = None ,
@@ -444,8 +441,6 @@ def fit(
444
441
The training input samples.
445
442
y : array-like if sklearn is available, otherwise array, shape (n_obs,)
446
443
The target values (real numbers).
447
- fit_method : str
448
- Which method to use to infer model parameters. One of ["mcmc", "MAP"].
449
444
progressbar : bool
450
445
Specifies whether the fit progressbar should be displayed
451
446
predictor_names: List[str] = None,
@@ -454,14 +449,19 @@ def fit(
454
449
random_seed : RandomState
455
450
Provides sampler with initial random seed for obtaining reproducible samples
456
451
**kwargs : Any
457
- Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for
458
- method-specific parameters.
452
+ Custom sampler settings can be provided in form of keyword arguments.
453
+
454
+ Returns
455
+ -------
456
+ self : az.InferenceData
457
+ returns inference data of the fitted model.
458
+ Examples
459
+ --------
460
+ >>> model = MyModel()
461
+ >>> idata = model.fit(data)
462
+ Auto-assigning NUTS sampler...
463
+ Initializing NUTS using jitter+adapt_diag...
459
464
"""
460
- available_methods = ["mcmc" , "MAP" ]
461
- if fit_method not in available_methods :
462
- raise ValueError (
463
- f"Inference method { fit_method } not found. Choose one of { available_methods } ."
464
- )
465
465
if predictor_names is None :
466
466
predictor_names = []
467
467
if y is None :
@@ -474,74 +474,14 @@ def fit(
474
474
sampler_config ["progressbar" ] = progressbar
475
475
sampler_config ["random_seed" ] = random_seed
476
476
sampler_config .update (** kwargs )
477
-
478
- if fit_method == "mcmc" :
479
- self .idata = self .sample_model (** sampler_config )
480
- elif fit_method == "MAP" :
481
- self .idata = self ._fit_MAP (** sampler_config )
477
+ self .idata = self .sample_model (** sampler_config )
482
478
483
479
X_df = pd .DataFrame (X , columns = X .columns )
484
480
combined_data = pd .concat ([X_df , y ], axis = 1 )
485
481
assert all (combined_data .columns ), "All columns must have non-empty names"
486
482
self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
487
483
return self .idata # type: ignore
488
484
489
- def _fit_MAP (
490
- self ,
491
- ** kwargs ,
492
- ):
493
- """Find model maximum a posteriori using scipy optimizer"""
494
-
495
- model = self .model
496
- find_MAP_args = {** self .sampler_config , ** kwargs }
497
- if "random_seed" in find_MAP_args :
498
- # find_MAP takes a different argument name for seed than sample_* do.
499
- find_MAP_args ["seed" ] = find_MAP_args ["random_seed" ]
500
- # Extra unknown arguments cause problems for SciPy minimize
501
- allowed_args = [ # find_MAP args
502
- "start" ,
503
- "vars" ,
504
- "method" ,
505
- # "return_raw", # probably causes a problem if set spuriously
506
- # "include_transformed", # probably causes a problem if set spuriously
507
- "progressbar" ,
508
- "maxeval" ,
509
- "seed" ,
510
- ]
511
- allowed_args += [ # scipy.optimize.minimize args
512
- # "fun", # used by find_MAP
513
- # "x0", # used by find_MAP
514
- "args" ,
515
- "method" ,
516
- # "jac", # used by find_MAP
517
- # "hess", # probably causes a problem if set spuriously
518
- # "hessp", # probably causes a problem if set spuriously
519
- "bounds" ,
520
- "constraints" ,
521
- "tol" ,
522
- "callback" ,
523
- "options" ,
524
- ]
525
- for arg in list (find_MAP_args ):
526
- if arg not in allowed_args :
527
- del find_MAP_args [arg ]
528
-
529
- map_res = pm .find_MAP (model = model , ** find_MAP_args )
530
- # Filter non-value variables
531
- value_vars_names = {v .name for v in model .value_vars }
532
- map_res = {k : v for k , v in map_res .items () if k in value_vars_names }
533
-
534
- # Convert map result to InferenceData
535
- map_strace = NDArray (model = model )
536
- map_strace .setup (draws = 1 , chain = 0 )
537
- map_strace .record (map_res )
538
- map_strace .close ()
539
- trace = MultiTrace ([map_strace ])
540
- idata = pm .to_inference_data (trace , model = model )
541
- self .set_idata_attrs (idata )
542
-
543
- return idata
544
-
545
485
def predict (
546
486
self ,
547
487
X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
0 commit comments