@@ -37,8 +37,8 @@ class ModelBuilder:
37
37
def __init__ (
38
38
self ,
39
39
model_config : Dict ,
40
- sampler_config : Dict ,
41
- data : Dict [ str , Union [ np . ndarray , pd . DataFrame , pd . Series ]] = None ,
40
+ data : Dict [ str , Union [ np . ndarray , pd . DataFrame , pd . Series ]] ,
41
+ sampler_config : Dict = None ,
42
42
):
43
43
"""
44
44
Initializes model configuration and sampler configuration for the model
@@ -47,10 +47,10 @@ def __init__(
47
47
----------
48
48
model_config : Dictionary
49
49
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method.
50
- sampler_config : Dictionary
51
- dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
52
50
data : Dictionary
53
51
It is the data we need to train the model on.
52
+ sampler_config : Dictionary
53
+ dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
54
54
Examples
55
55
--------
56
56
>>> class LinearModel(ModelBuilder):
@@ -60,9 +60,11 @@ def __init__(
60
60
61
61
super ().__init__ ()
62
62
self .model_config = model_config # parameters for priors etc.
63
- self .sample_config = sampler_config # parameters for sampling
64
- self .idata = None # inference data object
63
+ self .sampler_config = sampler_config # parameters for sampling
65
64
self .data = data
65
+ self .idata = (
66
+ None # inference data object placeholder, idata is generated during build execution
67
+ )
66
68
67
69
def build (self ) -> None :
68
70
"""
@@ -188,16 +190,20 @@ def load(cls, fname: str):
188
190
189
191
filepath = Path (str (fname ))
190
192
idata = az .from_netcdf (filepath )
193
+ if "sampler_config" in idata .attrs :
194
+ sampler_config = json .loads (idata .attrs ["sampler_config" ])
195
+ else :
196
+ sampler_config = None
191
197
model_builder = cls (
192
- json .loads (idata .attrs ["model_config" ]),
193
- json . loads ( idata . attrs [ " sampler_config" ]) ,
194
- idata .fit_data .to_dataframe (),
198
+ model_config = json .loads (idata .attrs ["model_config" ]),
199
+ sampler_config = sampler_config ,
200
+ data = idata .fit_data .to_dataframe (),
195
201
)
196
202
model_builder .idata = idata
197
203
model_builder .build ()
198
204
if model_builder .id != idata .attrs ["id" ]:
199
205
raise ValueError (
200
- f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ self ._model_type } '"
206
+ f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
201
207
)
202
208
203
209
return model_builder
@@ -206,7 +212,7 @@ def fit(
206
212
self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None
207
213
) -> az .InferenceData :
208
214
"""
209
- As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
215
+ Fit a model using the data passed as a parameter.
210
216
Sets attrs to inference data of the model.
211
217
212
218
Parameter
@@ -233,21 +239,26 @@ def fit(
233
239
self ._data_setter (data )
234
240
235
241
with self .model :
236
- self .idata = pm .sample (** self .sample_config )
242
+ if self .sampler_config is not None :
243
+ self .idata = pm .sample (** self .sampler_config )
244
+ else :
245
+ self .idata = pm .sample ()
237
246
self .idata .extend (pm .sample_prior_predictive ())
238
247
self .idata .extend (pm .sample_posterior_predictive (self .idata ))
239
248
240
249
self .idata .attrs ["id" ] = self .id
241
250
self .idata .attrs ["model_type" ] = self ._model_type
242
251
self .idata .attrs ["version" ] = self .version
243
- self .idata .attrs ["sampler_config" ] = json .dumps (self .sample_config )
252
+ if self .sampler_config is not None :
253
+ self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
244
254
self .idata .attrs ["model_config" ] = json .dumps (self .model_config )
245
255
self .idata .add_groups (fit_data = self .data .to_xarray ())
246
256
return self .idata
247
257
248
258
def predict (
249
259
self ,
250
260
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
261
+ extend_idata : bool = True ,
251
262
) -> dict :
252
263
"""
253
264
Uses model to predict on unseen data and return point prediction of all the samples
@@ -256,6 +267,8 @@ def predict(
256
267
---------
257
268
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
258
269
It is the data we need to make prediction on using the model.
270
+ extend_idata : Boolean determining whether the predictions should be added to inference data object.
271
+ Defaults to True.
259
272
260
273
Returns
261
274
-------
@@ -277,7 +290,8 @@ def predict(
277
290
278
291
with self .model : # sample with new input data
279
292
post_pred = pm .sample_posterior_predictive (self .idata )
280
-
293
+ if extend_idata :
294
+ self .idata .extend (post_pred )
281
295
# reshape output
282
296
post_pred = self ._extract_samples (post_pred )
283
297
for key in post_pred :
@@ -288,6 +302,7 @@ def predict(
288
302
def predict_posterior (
289
303
self ,
290
304
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
305
+ extend_idata : bool = True ,
291
306
) -> Dict [str , np .array ]:
292
307
"""
293
308
Uses model to predict samples on unseen data.
@@ -296,8 +311,8 @@ def predict_posterior(
296
311
---------
297
312
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
298
313
It is the data we need to make prediction on using the model.
299
- point_estimate : bool
300
- Adds point like estimate used as mean passed as
314
+ extend_idata : Boolean determining whether the predictions should be added to inference data object.
315
+ Defaults to True.
301
316
302
317
Returns
303
318
-------
@@ -319,6 +334,8 @@ def predict_posterior(
319
334
320
335
with self .model : # sample with new input data
321
336
post_pred = pm .sample_posterior_predictive (self .idata )
337
+ if extend_idata :
338
+ self .idata .extend (post_pred )
322
339
323
340
# reshape output
324
341
post_pred = self ._extract_samples (post_pred )
@@ -359,5 +376,4 @@ def id(self) -> str:
359
376
hasher .update (str (self .model_config .values ()).encode ())
360
377
hasher .update (self .version .encode ())
361
378
hasher .update (self ._model_type .encode ())
362
- # hasher.update(str(self.sample_config.values()).encode())
363
379
return hasher .hexdigest ()[:16 ]
0 commit comments