@@ -180,9 +180,11 @@ def save(self, fname: str) -> None:
180
180
>>> name = './mymodel.nc'
181
181
>>> model.save(name)
182
182
"""
183
-
184
- file = Path (str (fname ))
185
- self .idata .to_netcdf (file )
183
+ if self .idata is not None and "fit_data" in self .idata :
184
+ file = Path (str (fname ))
185
+ self .idata .to_netcdf (file )
186
+ else :
187
+ raise RuntimeError ("The model hasn't been fit yet, call .fit() first" )
186
188
187
189
@classmethod
188
190
def load (cls , fname : str ):
@@ -220,7 +222,7 @@ def load(cls, fname: str):
220
222
data = idata .fit_data .to_dataframe (),
221
223
)
222
224
model_builder .idata = idata
223
- model_builder .idata = model_builder .fit ( )
225
+ model_builder .build_model ( model_builder . data , model_builder .model_config )
224
226
if model_builder .id != idata .attrs ["id" ]:
225
227
raise ValueError (
226
228
f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
@@ -261,11 +263,12 @@ def fit(
261
263
# If a new data was provided, assign it to the model
262
264
if data is not None :
263
265
self .data = data
264
- self .model_data , self .model_config , self .sampler_config = self .create_sample_input (
265
- data = self .data
266
- )
266
+ self .model_data , model_config , sampler_config = self .create_sample_input (data = self .data )
267
+ if self .model_config is None :
268
+ self .model_config = model_config
269
+ if self .sampler_config is None :
270
+ self .sampler_config = sampler_config
267
271
self .build_model (self .model_data , self .model_config )
268
- self ._data_setter (self .model_data )
269
272
with self .model :
270
273
self .idata = pm .sample (** self .sampler_config , ** kwargs )
271
274
self .idata .extend (pm .sample_prior_predictive ())
@@ -275,7 +278,7 @@ def fit(
275
278
self .idata .attrs ["model_type" ] = self ._model_type
276
279
self .idata .attrs ["version" ] = self .version
277
280
self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
278
- self .idata .attrs ["model_config" ] = json .dumps (self .serializable_model_config )
281
+ self .idata .attrs ["model_config" ] = json .dumps (self ._serializable_model_config )
279
282
self .idata .add_groups (fit_data = self .data .to_xarray ())
280
283
return self .idata
281
284
@@ -386,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st
386
389
387
390
return post_pred_dict
388
391
392
+ @property
393
+ @abstractmethod
394
+ def _serializable_model_config (self ) -> Dict [str , Union [int , float , Dict ]]:
395
+ """
396
+ Converts non-serializable values from model_config to their serializable reversable equivalent.
397
+ Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
398
+ so in order to save the model they need to be formatted.
399
+
400
+ Returns
401
+ -------
402
+ model_config: dict
403
+ """
404
+
389
405
@property
390
406
def id (self ) -> str :
391
407
"""
0 commit comments