Skip to content

Commit 2da3c81

Browse files
authored
Refactor modelbuilder fit (#198)
* First pass at refactoring ModelBuilder.fit Added method-specific inference functions to ModelBuilder and made fit call them based on a fit_method parameter. Tests coming in the next commit. Also need to try to factor out code common to both methods. * Add tests for ModelBuilder.fit with different fit_method values Added fit_method as a parameter to the tests so that test models are tested with mcmc and MAP fitting methods. * Use context manager for tempfile When test_save_load failed, pytest complained about the temp file not being closed. It didn't seem to cause any problems, but maybe this is nicer. * Refactor ModelBuilder.fit again Refactored the ModelBuilder.fit method again to eliminate code redundancy between the two fitting methods. Eliminated extra kwargs from being passed to scipy.optimize.minimize.
1 parent 4ea3151 commit 2da3c81

File tree

2 files changed

+85
-26
lines changed

2 files changed

+85
-26
lines changed

pymc_experimental/model_builder.py

+73-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import pandas as pd
2525
import pymc as pm
2626
import xarray as xr
27+
from pymc.backends import NDArray
28+
from pymc.backends.base import MultiTrace
2729
from pymc.util import RandomState
2830

2931
# If scikit-learn is available, use its data validator
@@ -425,6 +427,7 @@ def fit(
425427
self,
426428
X: pd.DataFrame,
427429
y: Optional[pd.Series] = None,
430+
fit_method="mcmc",
428431
progressbar: bool = True,
429432
predictor_names: List[str] = None,
430433
random_seed: RandomState = None,
@@ -441,6 +444,8 @@ def fit(
441444
The training input samples.
442445
y : array-like if sklearn is available, otherwise array, shape (n_obs,)
443446
The target values (real numbers).
447+
fit_method : str
448+
Which method to use to infer model parameters. One of ["mcmc", "MAP"].
444449
progressbar : bool
445450
Specifies whether the fit progressbar should be displayed
446451
predictor_names: List[str] = None,
@@ -449,19 +454,14 @@ def fit(
449454
random_seed : RandomState
450455
Provides sampler with initial random seed for obtaining reproducible samples
451456
**kwargs : Any
452-
Custom sampler settings can be provided in form of keyword arguments.
453-
454-
Returns
455-
-------
456-
self : az.InferenceData
457-
returns inference data of the fitted model.
458-
Examples
459-
--------
460-
>>> model = MyModel()
461-
>>> idata = model.fit(data)
462-
Auto-assigning NUTS sampler...
463-
Initializing NUTS using jitter+adapt_diag...
457+
Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for
458+
method-specific parameters.
464459
"""
460+
available_methods = ["mcmc", "MAP"]
461+
if fit_method not in available_methods:
462+
raise ValueError(
463+
f"Inference method {fit_method} not found. Choose one of {available_methods}."
464+
)
465465
if predictor_names is None:
466466
predictor_names = []
467467
if y is None:
@@ -474,14 +474,74 @@ def fit(
474474
sampler_config["progressbar"] = progressbar
475475
sampler_config["random_seed"] = random_seed
476476
sampler_config.update(**kwargs)
477-
self.idata = self.sample_model(**sampler_config)
477+
478+
if fit_method == "mcmc":
479+
self.idata = self.sample_model(**sampler_config)
480+
elif fit_method == "MAP":
481+
self.idata = self._fit_MAP(**sampler_config)
478482

479483
X_df = pd.DataFrame(X, columns=X.columns)
480484
combined_data = pd.concat([X_df, y], axis=1)
481485
assert all(combined_data.columns), "All columns must have non-empty names"
482486
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
483487
return self.idata # type: ignore
484488

489+
def _fit_MAP(
490+
self,
491+
**kwargs,
492+
):
493+
"""Find model maximum a posteriori using scipy optimizer"""
494+
495+
model = self.model
496+
find_MAP_args = {**self.sampler_config, **kwargs}
497+
if "random_seed" in find_MAP_args:
498+
# find_MAP takes a different argument name for seed than sample_* do.
499+
find_MAP_args["seed"] = find_MAP_args["random_seed"]
500+
# Extra unknown arguments cause problems for SciPy minimize
501+
allowed_args = [ # find_MAP args
502+
"start",
503+
"vars",
504+
"method",
505+
# "return_raw", # probably causes a problem if set spuriously
506+
# "include_transformed", # probably causes a problem if set spuriously
507+
"progressbar",
508+
"maxeval",
509+
"seed",
510+
]
511+
allowed_args += [ # scipy.optimize.minimize args
512+
# "fun", # used by find_MAP
513+
# "x0", # used by find_MAP
514+
"args",
515+
"method",
516+
# "jac", # used by find_MAP
517+
# "hess", # probably causes a problem if set spuriously
518+
# "hessp", # probably causes a problem if set spuriously
519+
"bounds",
520+
"constraints",
521+
"tol",
522+
"callback",
523+
"options",
524+
]
525+
for arg in list(find_MAP_args):
526+
if arg not in allowed_args:
527+
del find_MAP_args[arg]
528+
529+
map_res = pm.find_MAP(model=model, **find_MAP_args)
530+
# Filter non-value variables
531+
value_vars_names = {v.name for v in model.value_vars}
532+
map_res = {k: v for k, v in map_res.items() if k in value_vars_names}
533+
534+
# Convert map result to InferenceData
535+
map_strace = NDArray(model=model)
536+
map_strace.setup(draws=1, chain=0)
537+
map_strace.record(map_res)
538+
map_strace.close()
539+
trace = MultiTrace([map_strace])
540+
idata = pm.to_inference_data(trace, model=model)
541+
self.set_idata_attrs(idata)
542+
543+
return idata
544+
485545
def predict(
486546
self,
487547
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],

pymc_experimental/tests/test_model_builder.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def toy_y(toy_X):
4040
return y
4141

4242

43-
@pytest.fixture(scope="module")
44-
def fitted_model_instance(toy_X, toy_y):
43+
@pytest.fixture(scope="module", params=["mcmc", "MAP"])
44+
def fitted_model_instance(request, toy_X, toy_y):
4545
sampler_config = {
4646
"draws": 500,
4747
"tune": 300,
@@ -54,12 +54,11 @@ def fitted_model_instance(toy_X, toy_y):
5454
"obs_error": 2,
5555
}
5656
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
57-
model.fit(toy_X)
57+
model.fit(toy_X, toy_y, fit_method=request.param)
5858
return model
5959

6060

6161
class test_ModelBuilder(ModelBuilder):
62-
6362
_model_type = "LinearModel"
6463
version = "0.1"
6564

@@ -151,9 +150,10 @@ def test_fit(fitted_model_instance):
151150
post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape
152151

153152

154-
def test_fit_no_y(toy_X):
153+
@pytest.mark.parametrize("fit_method", ["mcmc", "MAP"])
154+
def test_fit_no_y(toy_X, fit_method):
155155
model_builder = test_ModelBuilder()
156-
model_builder.idata = model_builder.fit(X=toy_X)
156+
model_builder.idata = model_builder.fit(X=toy_X, fit_method=fit_method)
157157
assert model_builder.model is not None
158158
assert model_builder.idata is not None
159159
assert "posterior" in model_builder.idata.groups()
@@ -163,17 +163,16 @@ def test_fit_no_y(toy_X):
163163
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
164164
)
165165
def test_save_load(fitted_model_instance):
166-
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
167-
fitted_model_instance.save(temp.name)
168-
test_builder2 = test_ModelBuilder.load(temp.name)
169-
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
166+
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as temp:
167+
fitted_model_instance.save(temp.name)
168+
test_builder2 = test_ModelBuilder.load(temp.name)
169+
assert sorted(fitted_model_instance.idata.groups()) == sorted(test_builder2.idata.groups())
170170

171171
x_pred = np.random.uniform(low=0, high=1, size=100)
172172
prediction_data = pd.DataFrame({"input": x_pred})
173173
pred1 = fitted_model_instance.predict(prediction_data["input"])
174174
pred2 = test_builder2.predict(prediction_data["input"])
175175
assert pred1.shape == pred2.shape
176-
temp.close()
177176

178177

179178
def test_predict(fitted_model_instance):
@@ -193,8 +192,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
193192
pred = fitted_model_instance.sample_posterior_predictive(
194193
prediction_data["input"], combined=combined, extend_idata=True
195194
)
196-
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
197-
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
195+
chains = fitted_model_instance.idata.posterior.dims["chain"]
196+
draws = fitted_model_instance.idata.posterior.dims["draw"]
198197
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
199198
assert pred[fitted_model_instance.output_var].shape == expected_shape
200199
assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)

0 commit comments

Comments
 (0)