Commit 37c4161

Authored and committed by Stijn de Boer (imec-OnePlanet)
Rework of modelbuilder save, load, fit and predict
- `generate_and_preprocess_model_data` is replaced by `preprocess_model_data` and `save_model_coords`. Both are called in `fit`; `predict` calls only `preprocess_model_data`.
- `self.X` and `self.y` are no longer used.
- Fit data is no longer stored in `idata`, because this could lead to data-sharing issues.
- `model_coords` are now saved in `idata.attrs`. This enables reconstruction of the correct computation graph when the correct test X is provided.
1 parent 9999e0f commit 37c4161
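For context, a minimal sketch of how a subclass might implement the two new hooks. The class, column names, and coord names below are hypothetical and not part of this commit; a real subclass would also override `build_model`, `default_model_config`, and the other abstract members.

```python
# Illustrative only: a stand-in class showing the reworked data flow.
import numpy as np
import pandas as pd


class MyModel:  # stand-in for a ModelBuilder subclass
    output_var = "y"

    def preprocess_model_data(self, X: pd.DataFrame, y: pd.Series = None):
        # Return preprocessed copies instead of assigning self.X / self.y,
        # which this commit removes.
        X_prep = X.copy()
        y_prep = None if y is None else pd.Series(np.asarray(y).ravel(), name=self.output_var)
        return X_prep, y_prep

    def save_model_coords(self, X: pd.DataFrame, y: pd.Series) -> None:
        # Derive dims from the training data; set_idata_attrs() later
        # serializes them into idata.attrs["model_coords"].
        self.model_coords = {"group": X["group"].unique().tolist()}


model = MyModel()
X_train = pd.DataFrame({"x": np.arange(5.0), "group": ["a", "a", "b", "b", "b"]})
y_train = pd.Series(np.arange(5.0))
X_prep, y_prep = model.preprocess_model_data(X_train, y_train)
model.save_model_coords(X_prep, y_prep)
print(model.model_coords)  # {'group': ['a', 'b']}
```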

File tree

2 files changed: +589, -38 lines changed

notebooks/modelbuilder_example.ipynb (+541): large diffs are not rendered by default.

pymc_experimental/model_builder.py (+48, -38):
@@ -74,8 +74,8 @@ def __init__(
         sampler_config = self.default_sampler_config if sampler_config is None else sampler_config
         self.sampler_config = sampler_config
         model_config = self.default_model_config if model_config is None else model_config
-
         self.model_config = model_config  # parameters for priors etc.
+        self.model_coords = None
         self.model = None  # Set by build_model
         self.idata: Optional[az.InferenceData] = None  # idata is generated during fitting
         self.is_fitted_ = False
@@ -172,7 +172,7 @@ def default_sampler_config(self) -> Dict:
         --------
         >>> @classmethod
         >>> def default_sampler_config(self):
-        >>>     Return {
+        >>>     return {
         >>>         'draws': 1_000,
         >>>         'tune': 1_000,
         >>>         'chains': 1,
@@ -187,13 +187,10 @@ def default_sampler_config(self) -> Dict:
         raise NotImplementedError

     @abstractmethod
-    def generate_and_preprocess_model_data(
-        self, X: Union[pd.DataFrame, pd.Series], y: pd.Series
-    ) -> None:
+    def preprocess_model_data(self, X: Union[pd.DataFrame, pd.Series], y: pd.Series = None) -> None:
         """
         Applies preprocessing to the data before fitting the model.
         if validate is True, it will check if the data is valid for the model.
-        sets self.model_coords based on provided dataset

         Parameters:
         X : array, shape (n_obs, n_features)
@@ -202,17 +199,16 @@ def generate_and_preprocess_model_data(
         Examples
         --------
         >>> @classmethod
-        >>> def generate_and_preprocess_model_data(self, X, y):
+        >>> def preprocess_model_data(self, X, y):
         >>>     x = np.linspace(start=1, stop=50, num=100)
         >>>     y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
         >>>     X = pd.DataFrame(x, columns=['x'])
         >>>     y = pd.Series(y, name='y')
-        >>>     self.X = X
-        >>>     self.y = y
+        >>>     return X, y

         Returns
         -------
-        None
+        pd.DataFrame, pd.Series

         """
         raise NotImplementedError
@@ -258,6 +254,23 @@ def build_model(
         """
         raise NotImplementedError

+    def save_model_coords(self, X: Union[pd.DataFrame, pd.Series], y: pd.Series):
+        """Creates the model coords.
+
+        Parameters:
+        X : array, shape (n_obs, n_features)
+        y : array, shape (n_obs,)
+
+        Examples
+        --------
+        def set_model_coords(self, X, y):
+            group_dim1 = X['group1'].unique()
+            group_dim2 = X['group2'].unique()
+
+            self.model_coords = {'group1':group_dim1, 'group2':group_dim2}
+        """
+        self.model_coords = None
+
     def sample_model(self, **kwargs):
         """
         Sample from the PyMC model.
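The coords saved here are what later fix the shape of the computation graph. A sketch of how they might feed into a subclass's model, assuming PyMC >= 5 and the hypothetical "group" coord from the docstring example above:

```python
import numpy as np
import pandas as pd
import pymc as pm

X = pd.DataFrame({"x": np.arange(6.0), "group": ["a", "a", "b", "b", "c", "c"]})
y = np.arange(6.0)
model_coords = {"group": X["group"].unique().tolist()}  # what save_model_coords would store

with pm.Model(coords=model_coords) as model:
    # One intercept per group; the dim length comes from model_coords, so the
    # same graph can later be rebuilt from idata.attrs alone.
    intercept = pm.Normal("intercept", mu=0.0, sigma=10.0, dims="group")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    group_idx = pd.Categorical(X["group"], categories=model_coords["group"]).codes
    pm.Normal("y", mu=intercept[group_idx], sigma=sigma, observed=y)
```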
@@ -339,6 +352,7 @@ def set_idata_attrs(self, idata=None):
         idata.attrs["version"] = self.version
         idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
         idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
+        idata.attrs["model_coords"] = json.dumps(self.model_coords)
         # Only classes with non-dataset parameters will implement save_input_params
         if hasattr(self, "_save_input_params"):
             self._save_input_params(idata)
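Because the coords now go through `json.dumps`, their values have to be JSON-serializable; numpy arrays such as the output of `.unique()` are not. A small helper one might use when building `model_coords` (hypothetical, not part of this commit):

```python
import json

import numpy as np


def jsonable_coords(coords: dict) -> dict:
    # json.dumps raises TypeError on numpy arrays, so convert to plain lists.
    return {
        dim: values.tolist() if isinstance(values, np.ndarray) else list(values)
        for dim, values in coords.items()
    }


coords = {"group1": np.array(["a", "b"]), "group2": np.array([1, 2, 3])}
print(json.dumps(jsonable_coords(coords)))  # {"group1": ["a", "b"], "group2": [1, 2, 3]}
```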
@@ -432,18 +446,12 @@ def load(cls, fname: str):
             model_config=model_config,
             sampler_config=json.loads(idata.attrs["sampler_config"]),
         )
+        model.model_coords = json.loads(idata.attrs["model_coords"])
         model.idata = idata
-        dataset = idata.fit_data.to_dataframe()
-        X = dataset.drop(columns=[model.output_var])
-        y = dataset[model.output_var]
-        model.build_model(X, y)
-        # All previously used data is in idata.
-
         if model.id != idata.attrs["id"]:
             raise ValueError(
                 f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
             )
-
         return model

     def fit(
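With the coords and configs in `idata.attrs`, the intended round trip no longer needs the training data at load time. A hedged usage sketch, assuming a fully implemented ModelBuilder subclass (here called `MyLinearModel`, hypothetical) and `X_train`, `y_train`, `X_test` already defined:

```python
model = MyLinearModel()
model.fit(X_train, y_train)
model.save("my_model.nc")  # hypothetical file name

# load() restores model_config, sampler_config and model_coords from idata.attrs;
# the computation graph itself is only rebuilt inside predict(), from the X you pass.
loaded = MyLinearModel.load("my_model.nc")
pred_mean = loaded.predict(X_test)
```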
@@ -462,7 +470,7 @@ def fit(

         Parameters
         ----------
-        X : array-like if sklearn is available, otherwise array, shape (n_obs, n_features)
+        X : pd.DataFrame (n_obs, n_features)
             The training input samples.
         y : array-like if sklearn is available, otherwise array, shape (n_obs,)
             The target values (real numbers).
@@ -492,26 +500,15 @@ def fit(
         if y is None:
             y = np.zeros(X.shape[0])
         y = pd.DataFrame({self.output_var: y})
-        self.generate_and_preprocess_model_data(X, y.values.flatten())
-        self.build_model(self.X, self.y)
+        X_prep, y_prep = self.preprocess_model_data(X, y.values.flatten())
+        self.save_model_coords(X_prep, y_prep)
+        self.build_model(X_prep, y_prep)

         sampler_config = self.sampler_config.copy()
         sampler_config["progressbar"] = progressbar
         sampler_config["random_seed"] = random_seed
         sampler_config.update(**kwargs)
         self.idata = self.sample_model(**sampler_config)
-
-        X_df = pd.DataFrame(X, columns=X.columns)
-        combined_data = pd.concat([X_df, y], axis=1)
-        assert all(combined_data.columns), "All columns must have non-empty names"
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore",
-                category=UserWarning,
-                message="The group fit_data is not defined in the InferenceData scheme",
-            )
-            self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore
-
         return self.idata  # type: ignore

     def predict(
@@ -526,7 +523,7 @@ def predict(

         Parameters
         ---------
-        X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features)
+        X_pred : pd.DataFrame (n_pred, n_features)
             The input data used for prediction.
         extend_idata : Boolean determining whether the predictions should be added to inference data object.
             Defaults to True.
@@ -545,9 +542,12 @@ def predict(
         >>> prediction_data = pd.DataFrame({'input':x_pred})
         >>> pred_mean = model.predict(prediction_data)
         """
-
+        synth_y = pd.Series(np.zeros(len(X_pred)))
+        X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y)
+        if self.model is None:
+            self.build_model(X_pred_prep, y_synth_prep)
         posterior_predictive_samples = self.sample_posterior_predictive(
-            X_pred, extend_idata, combined=False, **kwargs
+            X_pred_prep, extend_idata, combined=False, **kwargs
         )

         if self.output_var not in posterior_predictive_samples:
@@ -652,6 +652,7 @@ def get_params(self, deep=True):
         return {
             "model_config": self.model_config,
             "sampler_config": self.sampler_config,
+            "model_coords": self.model_coords,
         }

     def set_params(self, **params):
@@ -660,6 +661,7 @@ def set_params(self, **params):
         """
         self.model_config = params["model_config"]
         self.sampler_config = params["sampler_config"]
+        self.model_coords = params["model_coords"]

     @property
     @abstractmethod
@@ -682,7 +684,11 @@ def predict_proba(
         **kwargs,
     ) -> xr.DataArray:
         """Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators."""
-        return self.predict_posterior(X_pred, extend_idata, combined, **kwargs)
+        synth_y = pd.Series(np.zeros(len(X_pred)))
+        X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y)
+        if self.model is None:
+            self.build_model(X_pred_prep, y_synth_prep)
+        return self.predict_posterior(X_pred_prep, extend_idata, combined, **kwargs)

     def predict_posterior(
         self,
@@ -710,9 +716,13 @@ def predict_posterior(
             Posterior predictive samples for each input X_pred
         """

-        X_pred = self._validate_data(X_pred)
+        synth_y = pd.Series(np.zeros(len(X_pred)))
+        X_pred_prep, y_synth_prep = self.preprocess_model_data(X_pred, synth_y)
+        if self.model is None:
+            self.build_model(X_pred_prep, y_synth_prep)
+
         posterior_predictive_samples = self.sample_posterior_predictive(
-            X_pred, extend_idata, combined, **kwargs
+            X_pred_prep, extend_idata, combined=False, **kwargs
         )

         if self.output_var not in posterior_predictive_samples: