@@ -74,8 +74,8 @@ def __init__(
74
74
sampler_config = self .default_sampler_config if sampler_config is None else sampler_config
75
75
self .sampler_config = sampler_config
76
76
model_config = self .default_model_config if model_config is None else model_config
77
-
78
77
self .model_config = model_config # parameters for priors etc.
78
+ self .model_coords = None
79
79
self .model = None # Set by build_model
80
80
self .idata : Optional [az .InferenceData ] = None # idata is generated during fitting
81
81
self .is_fitted_ = False
@@ -172,7 +172,7 @@ def default_sampler_config(self) -> Dict:
172
172
--------
173
173
>>> @classmethod
174
174
>>> def default_sampler_config(self):
175
- >>> Return {
175
+ >>> return {
176
176
>>> 'draws': 1_000,
177
177
>>> 'tune': 1_000,
178
178
>>> 'chains': 1,
@@ -187,13 +187,10 @@ def default_sampler_config(self) -> Dict:
187
187
raise NotImplementedError
188
188
189
189
@abstractmethod
190
- def generate_and_preprocess_model_data (
191
- self , X : Union [pd .DataFrame , pd .Series ], y : pd .Series
192
- ) -> None :
190
+ def preprocess_model_data (self , X : Union [pd .DataFrame , pd .Series ], y : pd .Series = None ) -> None :
193
191
"""
194
192
Applies preprocessing to the data before fitting the model.
195
193
if validate is True, it will check if the data is valid for the model.
196
- sets self.model_coords based on provided dataset
197
194
198
195
Parameters:
199
196
X : array, shape (n_obs, n_features)
@@ -202,17 +199,16 @@ def generate_and_preprocess_model_data(
202
199
Examples
203
200
--------
204
201
>>> @classmethod
205
- >>> def generate_and_preprocess_model_data (self, X, y):
202
+ >>> def preprocess_model_data (self, X, y):
206
203
>>> x = np.linspace(start=1, stop=50, num=100)
207
204
>>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
208
205
>>> X = pd.DataFrame(x, columns=['x'])
209
206
>>> y = pd.Series(y, name='y')
210
- >>> self.X = X
211
- >>> self.y = y
207
+ >>> return X, y
212
208
213
209
Returns
214
210
-------
215
- None
211
+ pd.DataFrame, pd.Series
216
212
217
213
"""
218
214
raise NotImplementedError
@@ -258,6 +254,23 @@ def build_model(
258
254
"""
259
255
raise NotImplementedError
260
256
257
+ def save_model_coords (self , X : Union [pd .DataFrame , pd .Series ], y : pd .Series ):
258
+ """Creates the model coords.
259
+
260
+ Parameters:
261
+ X : array, shape (n_obs, n_features)
262
+ y : array, shape (n_obs,)
263
+
264
+ Examples
265
+ --------
266
+ def set_model_coords(self, X, y):
267
+ group_dim1 = X['group1'].unique()
268
+ group_dim2 = X['group2'].unique()
269
+
270
+ self.model_coords = {'group1':group_dim1, 'group2':group_dim2}
271
+ """
272
+ self .model_coords = None
273
+
261
274
def sample_model (self , ** kwargs ):
262
275
"""
263
276
Sample from the PyMC model.
@@ -339,6 +352,7 @@ def set_idata_attrs(self, idata=None):
339
352
idata .attrs ["version" ] = self .version
340
353
idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
341
354
idata .attrs ["model_config" ] = json .dumps (self ._serializable_model_config )
355
+ idata .attrs ["model_coords" ] = json .dumps (self .model_coords )
342
356
# Only classes with non-dataset parameters will implement save_input_params
343
357
if hasattr (self , "_save_input_params" ):
344
358
self ._save_input_params (idata )
@@ -432,18 +446,12 @@ def load(cls, fname: str):
432
446
model_config = model_config ,
433
447
sampler_config = json .loads (idata .attrs ["sampler_config" ]),
434
448
)
449
+ model .model_coords = json .loads (idata .attrs ["model_coords" ])
435
450
model .idata = idata
436
- dataset = idata .fit_data .to_dataframe ()
437
- X = dataset .drop (columns = [model .output_var ])
438
- y = dataset [model .output_var ]
439
- model .build_model (X , y )
440
- # All previously used data is in idata.
441
-
442
451
if model .id != idata .attrs ["id" ]:
443
452
raise ValueError (
444
453
f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
445
454
)
446
-
447
455
return model
448
456
449
457
def fit (
@@ -462,7 +470,7 @@ def fit(
462
470
463
471
Parameters
464
472
----------
465
- X : array-like if sklearn is available, otherwise array, shape (n_obs, n_features)
473
+ X : pd.DataFrame (n_obs, n_features)
466
474
The training input samples.
467
475
y : array-like if sklearn is available, otherwise array, shape (n_obs,)
468
476
The target values (real numbers).
@@ -492,26 +500,15 @@ def fit(
492
500
if y is None :
493
501
y = np .zeros (X .shape [0 ])
494
502
y = pd .DataFrame ({self .output_var : y })
495
- self .generate_and_preprocess_model_data (X , y .values .flatten ())
496
- self .build_model (self .X , self .y )
503
+ X_prep , y_prep = self .preprocess_model_data (X , y .values .flatten ())
504
+ self .save_model_coords (X_prep , y_prep )
505
+ self .build_model (X_prep , y_prep )
497
506
498
507
sampler_config = self .sampler_config .copy ()
499
508
sampler_config ["progressbar" ] = progressbar
500
509
sampler_config ["random_seed" ] = random_seed
501
510
sampler_config .update (** kwargs )
502
511
self .idata = self .sample_model (** sampler_config )
503
-
504
- X_df = pd .DataFrame (X , columns = X .columns )
505
- combined_data = pd .concat ([X_df , y ], axis = 1 )
506
- assert all (combined_data .columns ), "All columns must have non-empty names"
507
- with warnings .catch_warnings ():
508
- warnings .filterwarnings (
509
- "ignore" ,
510
- category = UserWarning ,
511
- message = "The group fit_data is not defined in the InferenceData scheme" ,
512
- )
513
- self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
514
-
515
512
return self .idata # type: ignore
516
513
517
514
def predict (
@@ -526,7 +523,7 @@ def predict(
526
523
527
524
Parameters
528
525
---------
529
- X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features)
526
+ X_pred : pd.DataFrame (n_pred, n_features)
530
527
The input data used for prediction.
531
528
extend_idata : Boolean determining whether the predictions should be added to inference data object.
532
529
Defaults to True.
@@ -545,9 +542,12 @@ def predict(
545
542
>>> prediction_data = pd.DataFrame({'input':x_pred})
546
543
>>> pred_mean = model.predict(prediction_data)
547
544
"""
548
-
545
+ synth_y = pd .Series (np .zeros (len (X_pred )))
546
+ X_pred_prep , y_synth_prep = self .preprocess_model_data (X_pred , synth_y )
547
+ if self .model is None :
548
+ self .build_model (X_pred_prep , y_synth_prep )
549
549
posterior_predictive_samples = self .sample_posterior_predictive (
550
- X_pred , extend_idata , combined = False , ** kwargs
550
+ X_pred_prep , extend_idata , combined = False , ** kwargs
551
551
)
552
552
553
553
if self .output_var not in posterior_predictive_samples :
@@ -652,6 +652,7 @@ def get_params(self, deep=True):
652
652
return {
653
653
"model_config" : self .model_config ,
654
654
"sampler_config" : self .sampler_config ,
655
+ "model_coords" : self .model_coords ,
655
656
}
656
657
657
658
def set_params (self , ** params ):
@@ -660,6 +661,7 @@ def set_params(self, **params):
660
661
"""
661
662
self .model_config = params ["model_config" ]
662
663
self .sampler_config = params ["sampler_config" ]
664
+ self .model_coords = params ["model_coords" ]
663
665
664
666
@property
665
667
@abstractmethod
@@ -682,7 +684,11 @@ def predict_proba(
682
684
** kwargs ,
683
685
) -> xr .DataArray :
684
686
"""Alias for `predict_posterior`, for consistency with scikit-learn probabilistic estimators."""
685
- return self .predict_posterior (X_pred , extend_idata , combined , ** kwargs )
687
+ synth_y = pd .Series (np .zeros (len (X_pred )))
688
+ X_pred_prep , y_synth_prep = self .preprocess_model_data (X_pred , synth_y )
689
+ if self .model is None :
690
+ self .build_model (X_pred_prep , y_synth_prep )
691
+ return self .predict_posterior (X_pred_prep , extend_idata , combined , ** kwargs )
686
692
687
693
def predict_posterior (
688
694
self ,
@@ -710,9 +716,13 @@ def predict_posterior(
710
716
Posterior predictive samples for each input X_pred
711
717
"""
712
718
713
- X_pred = self ._validate_data (X_pred )
719
+ synth_y = pd .Series (np .zeros (len (X_pred )))
720
+ X_pred_prep , y_synth_prep = self .preprocess_model_data (X_pred , synth_y )
721
+ if self .model is None :
722
+ self .build_model (X_pred_prep , y_synth_prep )
723
+
714
724
posterior_predictive_samples = self .sample_posterior_predictive (
715
- X_pred , extend_idata , combined , ** kwargs
725
+ X_pred_prep , extend_idata , combined = False , ** kwargs
716
726
)
717
727
718
728
if self .output_var not in posterior_predictive_samples :
0 commit comments