
Commit dd04eff

Revert "Refactor modelbuilder fit (#198)" (#200)
This reverts commit 2da3c81.
1 parent: 356232c · commit: dd04eff

File tree: 2 files changed (+26, -85 lines)

pymc_experimental/model_builder.py (+13, -73)

@@ -24,8 +24,6 @@
 import pandas as pd
 import pymc as pm
 import xarray as xr
-from pymc.backends import NDArray
-from pymc.backends.base import MultiTrace
 from pymc.util import RandomState

 # If scikit-learn is available, use its data validator
@@ -427,7 +425,6 @@ def fit(
         self,
         X: pd.DataFrame,
         y: Optional[pd.Series] = None,
-        fit_method="mcmc",
         progressbar: bool = True,
         predictor_names: List[str] = None,
         random_seed: RandomState = None,
@@ -444,8 +441,6 @@ def fit(
             The training input samples.
         y : array-like if sklearn is available, otherwise array, shape (n_obs,)
             The target values (real numbers).
-        fit_method : str
-            Which method to use to infer model parameters. One of ["mcmc", "MAP"].
         progressbar : bool
             Specifies whether the fit progressbar should be displayed
         predictor_names: List[str] = None,
@@ -454,14 +449,19 @@ def fit(
         random_seed : RandomState
             Provides sampler with initial random seed for obtaining reproducible samples
         **kwargs : Any
-            Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for
-            method-specific parameters.
+            Custom sampler settings can be provided in form of keyword arguments.
+
+        Returns
+        -------
+        self : az.InferenceData
+            returns inference data of the fitted model.
+        Examples
+        --------
+        >>> model = MyModel()
+        >>> idata = model.fit(data)
+        Auto-assigning NUTS sampler...
+        Initializing NUTS using jitter+adapt_diag...
         """
-        available_methods = ["mcmc", "MAP"]
-        if fit_method not in available_methods:
-            raise ValueError(
-                f"Inference method {fit_method} not found. Choose one of {available_methods}."
-            )
         if predictor_names is None:
             predictor_names = []
         if y is None:
@@ -474,74 +474,14 @@ def fit(
         sampler_config["progressbar"] = progressbar
         sampler_config["random_seed"] = random_seed
         sampler_config.update(**kwargs)
-
-        if fit_method == "mcmc":
-            self.idata = self.sample_model(**sampler_config)
-        elif fit_method == "MAP":
-            self.idata = self._fit_MAP(**sampler_config)
+        self.idata = self.sample_model(**sampler_config)

         X_df = pd.DataFrame(X, columns=X.columns)
         combined_data = pd.concat([X_df, y], axis=1)
         assert all(combined_data.columns), "All columns must have non-empty names"
         self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore
         return self.idata  # type: ignore

-    def _fit_MAP(
-        self,
-        **kwargs,
-    ):
-        """Find model maximum a posteriori using scipy optimizer"""
-
-        model = self.model
-        find_MAP_args = {**self.sampler_config, **kwargs}
-        if "random_seed" in find_MAP_args:
-            # find_MAP takes a different argument name for seed than sample_* do.
-            find_MAP_args["seed"] = find_MAP_args["random_seed"]
-        # Extra unknown arguments cause problems for SciPy minimize
-        allowed_args = [  # find_MAP args
-            "start",
-            "vars",
-            "method",
-            # "return_raw",  # probably causes a problem if set spuriously
-            # "include_transformed",  # probably causes a problem if set spuriously
-            "progressbar",
-            "maxeval",
-            "seed",
-        ]
-        allowed_args += [  # scipy.optimize.minimize args
-            # "fun",  # used by find_MAP
-            # "x0",  # used by find_MAP
-            "args",
-            "method",
-            # "jac",  # used by find_MAP
-            # "hess",  # probably causes a problem if set spuriously
-            # "hessp",  # probably causes a problem if set spuriously
-            "bounds",
-            "constraints",
-            "tol",
-            "callback",
-            "options",
-        ]
-        for arg in list(find_MAP_args):
-            if arg not in allowed_args:
-                del find_MAP_args[arg]
-
-        map_res = pm.find_MAP(model=model, **find_MAP_args)
-        # Filter non-value variables
-        value_vars_names = {v.name for v in model.value_vars}
-        map_res = {k: v for k, v in map_res.items() if k in value_vars_names}
-
-        # Convert map result to InferenceData
-        map_strace = NDArray(model=model)
-        map_strace.setup(draws=1, chain=0)
-        map_strace.record(map_res)
-        map_strace.close()
-        trace = MultiTrace([map_strace])
-        idata = pm.to_inference_data(trace, model=model)
-        self.set_idata_attrs(idata)
-
-        return idata
-
     def predict(
         self,
         X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
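
After this revert, ModelBuilder.fit always samples with MCMC through sample_model; the fit_method="MAP" option and the _fit_MAP helper removed above are gone. For readers who relied on that path, here is a rough standalone sketch of the same idea, pm.find_MAP wrapped into an InferenceData via a one-draw NDArray trace, against a hypothetical toy model (the model and data below are illustrative only, not part of this repo):

import pymc as pm
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace

# Hypothetical toy model, used only to illustrate the reverted pattern.
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=[0.1, -0.3, 0.2])

# Optimize to the MAP point, then keep only the model's value variables,
# mirroring how the removed _fit_MAP helper filtered its result.
map_res = pm.find_MAP(model=model, progressbar=False)
value_vars_names = {v.name for v in model.value_vars}
map_res = {k: v for k, v in map_res.items() if k in value_vars_names}

# Record the single point into a one-draw trace and convert it to InferenceData.
map_strace = NDArray(model=model)
map_strace.setup(draws=1, chain=0)
map_strace.record(map_res)
map_strace.close()
idata = pm.to_inference_data(MultiTrace([map_strace]), model=model)

The resulting idata has a posterior group with one chain and one draw, which is what allowed the removed code path to reuse the MCMC-oriented predict machinery.
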

pymc_experimental/tests/test_model_builder.py (+13, -12)

@@ -40,8 +40,8 @@ def toy_y(toy_X):
     return y


-@pytest.fixture(scope="module", params=["mcmc", "MAP"])
-def fitted_model_instance(request, toy_X, toy_y):
+@pytest.fixture(scope="module")
+def fitted_model_instance(toy_X, toy_y):
     sampler_config = {
         "draws": 500,
         "tune": 300,
@@ -54,11 +54,12 @@ def fitted_model_instance(request, toy_X, toy_y):
         "obs_error": 2,
     }
     model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
-    model.fit(toy_X, toy_y, fit_method=request.param)
+    model.fit(toy_X)
     return model


 class test_ModelBuilder(ModelBuilder):
+
     _model_type = "LinearModel"
     version = "0.1"

@@ -150,10 +151,9 @@ def test_fit(fitted_model_instance):
     post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape


-@pytest.mark.parametrize("fit_method", ["mcmc", "MAP"])
-def test_fit_no_y(toy_X, fit_method):
+def test_fit_no_y(toy_X):
     model_builder = test_ModelBuilder()
-    model_builder.idata = model_builder.fit(X=toy_X, fit_method=fit_method)
+    model_builder.idata = model_builder.fit(X=toy_X)
     assert model_builder.model is not None
     assert model_builder.idata is not None
     assert "posterior" in model_builder.idata.groups()
@@ -163,16 +163,17 @@ def test_fit_no_y(toy_X, fit_method):
     sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
 )
 def test_save_load(fitted_model_instance):
-    with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as temp:
-        fitted_model_instance.save(temp.name)
-        test_builder2 = test_ModelBuilder.load(temp.name)
-        assert sorted(fitted_model_instance.idata.groups()) == sorted(test_builder2.idata.groups())
+    temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
+    fitted_model_instance.save(temp.name)
+    test_builder2 = test_ModelBuilder.load(temp.name)
+    assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()

     x_pred = np.random.uniform(low=0, high=1, size=100)
     prediction_data = pd.DataFrame({"input": x_pred})
     pred1 = fitted_model_instance.predict(prediction_data["input"])
     pred2 = test_builder2.predict(prediction_data["input"])
     assert pred1.shape == pred2.shape
+    temp.close()


 def test_predict(fitted_model_instance):
@@ -192,8 +193,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
     pred = fitted_model_instance.sample_posterior_predictive(
         prediction_data["input"], combined=combined, extend_idata=True
     )
-    chains = fitted_model_instance.idata.posterior.dims["chain"]
-    draws = fitted_model_instance.idata.posterior.dims["draw"]
+    chains = fitted_model_instance.idata.sample_stats.dims["chain"]
+    draws = fitted_model_instance.idata.sample_stats.dims["draw"]
     expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
     assert pred[fitted_model_instance.output_var].shape == expected_shape
     assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
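
The last hunk also switches test_sample_posterior_predictive back to reading the chain and draw counts from idata.sample_stats instead of idata.posterior. For an MCMC fit both groups carry the same chain/draw dimensions, so either lookup works; a minimal sketch with made-up arrays (hypothetical values, not the real fixture) illustrates the lookup and the expected prediction shapes:

import numpy as np
import arviz as az

# Made-up draws with 2 chains x 50 draws; after MCMC both the "posterior"
# and "sample_stats" groups share the chain/draw dimensions.
idata = az.from_dict(
    posterior={"mu": np.random.randn(2, 50)},
    sample_stats={"lp": np.random.randn(2, 50)},
)
chains = idata.sample_stats.dims["chain"]  # 2
draws = idata.sample_stats.dims["draw"]    # 50

n_pred = 100  # hypothetical number of prediction points
expected_combined = (n_pred, chains * draws)  # shape when combined=True
expected_separate = (chains, draws, n_pred)   # shape when combined=False
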
