@@ -55,9 +55,9 @@ def __init__(
55
55
dictionary of parameters that initialise sampler configuration. Class-default defined by the user default_sampler_config method.
56
56
Examples
57
57
--------
58
- >>> class LinearModel (ModelBuilder):
58
+ >>> class MyModel (ModelBuilder):
59
59
>>> ...
60
- >>> model = LinearModel (model_config, sampler_config)
60
+ >>> model = MyModel (model_config, sampler_config)
61
61
"""
62
62
63
63
if sampler_config is None :
@@ -159,6 +159,12 @@ def generate_model_data(
159
159
) -> pd .DataFrame :
160
160
"""
161
161
Returns a default dataset for a class, can be used as a hint to data formatting required for the class
162
+ If data is not None, the dataset will be created from its content.
163
+
164
+ Parameters:
165
+ data : Union[np.ndarray, pd.DataFrame, pd.Series], optional
166
+ dataset that will replace the default sample data
167
+
162
168
163
169
Examples
164
170
--------
@@ -178,16 +184,16 @@ def generate_model_data(
178
184
179
185
@abstractmethod
180
186
def build_model (
181
- model_data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
187
+ data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
182
188
model_config : Dict [str , Union [int , float , Dict ]] = None ,
183
189
) -> None :
184
190
"""
185
- Creates an instance of pm.Model based on provided model_data and model_config, and
191
+ Creates an instance of pm.Model based on provided data and model_config, and
186
192
attaches it to self.
187
193
188
194
Parameters
189
195
----------
190
- model_data : dict
196
+ data : dict
191
197
Preformatted data that is going to be used in the model. For efficiency reasons it should contain only the necessary data columns,
192
198
not entire available dataset since it's going to be encoded into data used to recreate the model.
193
199
If not provided uses data from self.data
@@ -207,7 +213,34 @@ def build_model(
207
213
raise NotImplementedError
208
214
209
215
def sample_model (self , ** kwargs ):
216
+ """
217
+ Sample from the PyMC model.
218
+
219
+ Parameters
220
+ ----------
221
+ **kwargs : dict
222
+ Additional keyword arguments to pass to the PyMC sampler.
223
+
224
+ Returns
225
+ -------
226
+ xarray.Dataset
227
+ The PyMC samples dataset.
228
+
229
+ Raises
230
+ ------
231
+ RuntimeError
232
+ If the PyMC model hasn't been built yet.
210
233
234
+ Examples
235
+ --------
236
+ >>> self.build_model()
237
+ >>> idata = self.sample_model(draws=100, tune=10)
238
+ >>> assert isinstance(idata, xr.Dataset)
239
+ >>> assert "posterior" in idata
240
+ >>> assert "prior" in idata
241
+ >>> assert "observed_data" in idata
242
+ >>> assert "log_likelihood" in idata
243
+ """
211
244
if self .model is None :
212
245
raise RuntimeError (
213
246
"The model hasn't been built yet, call .build_model() first or call .fit() instead."
@@ -223,6 +256,34 @@ def sample_model(self, **kwargs):
223
256
return idata
224
257
225
258
def set_idata_attrs (self , idata = None ):
259
+ """
260
+ Set attributes on an InferenceData object.
261
+
262
+ Parameters
263
+ ----------
264
+ idata : arviz.InferenceData, optional
265
+ The InferenceData object to set attributes on.
266
+
267
+ Raises
268
+ ------
269
+ RuntimeError
270
+ If no InferenceData object is provided.
271
+
272
+ Returns
273
+ -------
274
+ None
275
+
276
+ Examples
277
+ --------
278
+ >>> model = MyModel()
279
+ >>> idata = az.InferenceData(your_dataset)
280
+ >>> model.set_idata_attrs(idata=idata)
281
+ >>> assert "id" in idata.attrs #this and the following lines are part of doctest, not user manual
282
+ >>> assert "model_type" in idata.attrs
283
+ >>> assert "version" in idata.attrs
284
+ >>> assert "sampler_config" in idata.attrs
285
+ >>> assert "model_config" in idata.attrs
286
+ """
226
287
if idata is None :
227
288
idata = self .idata
228
289
if idata is None :
@@ -235,22 +296,33 @@ def set_idata_attrs(self, idata=None):
235
296
236
297
def save (self , fname : str ) -> None :
237
298
"""
238
- Saves inference data of the model .
299
+ Save the model's inference data to a file .
239
300
240
301
Parameters
241
302
----------
242
- fname : string
243
- This denotes the name with path from where idata should be saved.
303
+ fname : str
304
+ The name and path of the file to save the inference data with model parameters.
305
+
306
+ Returns
307
+ -------
308
+ None
309
+
310
+ Raises
311
+ ------
312
+ RuntimeError
313
+ If the model hasn't been fit yet (no inference data available).
244
314
245
315
Examples
246
316
--------
247
- >>> class LinearModel(ModelBuilder):
248
- >>> ...
249
- >>> data, model_config, sampler_config = LinearModel.create_sample_input()
250
- >>> model = LinearModel(model_config, sampler_config)
251
- >>> idata = model.fit(data)
252
- >>> name = './mymodel.nc'
253
- >>> model.save(name)
317
+ This method is meant to be overridden and implemented by subclasses.
318
+ It should not be called directly on the base abstract class or its instances.
319
+
320
+ >>> class MyModel(ModelBuilder):
321
+ >>> def __init__(self):
322
+ >>> super().__init__()
323
+ >>> model = MyModel()
324
+ >>> model.fit(data)
325
+ >>> model.save('model_results.nc') # This will call the overridden method in MyModel
254
326
"""
255
327
if self .idata is not None and "posterior" in self .idata :
256
328
file = Path (str (fname ))
@@ -259,33 +331,32 @@ def save(self, fname: str) -> None:
259
331
raise RuntimeError ("The model hasn't been fit yet, call .fit() first" )
260
332
261
333
@classmethod
262
- def load (cls , fname : str ):
334
+ def load (cls , fname : str ) -> "ModelBuilder" :
263
335
"""
264
- Creates a ModelBuilder instance from a file,
265
- Loads inference data for the model.
336
+ Create a ModelBuilder instance from a file and load inference data for the model.
266
337
267
338
Parameters
268
339
----------
269
- fname : string
270
- This denotes the name with path from where idata should be loaded from .
340
+ fname : str
341
+ The name and path from which the inference data should be loaded.
271
342
272
343
Returns
273
344
-------
274
- Returns an instance of ModelBuilder.
345
+ ModelBuilder
346
+ An instance of the ModelBuilder class.
275
347
276
348
Raises
277
349
------
278
350
ValueError
279
- If the inference data that is loaded doesn't match with the model.
351
+ If the loaded inference data does not match the model.
280
352
281
353
Examples
282
354
--------
283
- >>> class LinearModel (ModelBuilder):
355
+ >>> class MyModel (ModelBuilder):
284
356
>>> ...
285
357
>>> name = './mymodel.nc'
286
- >>> imported_model = LinearModel .load(name)
358
+ >>> imported_model = MyModel .load(name)
287
359
"""
288
-
289
360
filepath = Path (str (fname ))
290
361
idata = az .from_netcdf (filepath )
291
362
model_builder = cls (
@@ -297,7 +368,7 @@ def load(cls, fname: str):
297
368
model_builder .build_model (model_builder .data , model_builder .model_config )
298
369
if model_builder .id != idata .attrs ["id" ]:
299
370
raise ValueError (
300
- f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
371
+ f"The file '{ fname } ' does not contain inference data of the same model or configuration as '{ cls ._model_type } '"
301
372
)
302
373
303
374
return model_builder
@@ -331,28 +402,27 @@ def fit(
331
402
332
403
Examples
333
404
--------
334
- >>> data, model_config, sampler_config = LinearModel.create_sample_input()
335
- >>> model = LinearModel(model_config, sampler_config)
405
+ >>> model = MyModel()
336
406
>>> idata = model.fit(data)
337
407
Auto-assigning NUTS sampler...
338
408
Initializing NUTS using jitter+adapt_diag...
339
409
"""
340
410
341
411
# If new data was provided, assign it to the model
342
412
if data is not None :
343
- self .data = data
344
- self . model_data = self . generate_model_data ( data = self . data )
413
+ self .data = self . generate_model_data ( data = self . data )
414
+
345
415
if self .model_config is None :
346
416
self .model_config = self .default_model_config
347
417
if self .sampler_config is None :
348
418
self .sampler_config = self .default_sampler_config
349
419
if self .model is None :
350
- self .build_model (self .model_data , self .model_config )
420
+ self .build_model (self .data , self .model_config )
351
421
352
422
self .sampler_config ["progressbar" ] = progressbar
353
423
self .sampler_config ["random_seed" ] = random_seed
354
-
355
- self .idata = self .sample_model (** self . sampler_config )
424
+ temp_sampler_config = { ** self . sampler_config , ** kwargs }
425
+ self .idata = self .sample_model (** temp_sampler_config )
356
426
self .idata .add_groups (fit_data = self .data .to_xarray ())
357
427
return self .idata
358
428
@@ -377,8 +447,7 @@ def predict(
377
447
378
448
Examples
379
449
--------
380
- >>> data, model_config, sampler_config = LinearModel.create_sample_input()
381
- >>> model = LinearModel(model_config, sampler_config)
450
+ >>> model = MyModel()
382
451
>>> idata = model.fit(data)
383
452
>>> x_pred = []
384
453
>>> prediction_data = pd.DataFrame({'input':x_pred})
@@ -398,7 +467,7 @@ def predict_posterior(
398
467
Generate posterior predictive samples on unseen data.
399
468
400
469
Parameters
401
- ---------
470
+ ----------
402
471
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
403
472
It is the data we need to make prediction on using the model.
404
473
extend_idata : Boolean determining whether the predictions should be added to inference data object.
@@ -412,8 +481,7 @@ def predict_posterior(
412
481
413
482
Examples
414
483
--------
415
- >>> data, model_config, sampler_config = LinearModel.create_sample_input()
416
- >>> model = LinearModel(model_config, sampler_config)
484
+ >>> model = MyModel()
417
485
>>> idata = model.fit(data)
418
486
>>> x_pred = []
419
487
>>> prediction_data = pd.DataFrame({'input': x_pred})
@@ -450,13 +518,22 @@ def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
450
518
@property
451
519
def id (self ) -> str :
452
520
"""
453
- It creates a hash value to match the model version using last 16 characters of hash encoding.
521
+ Generate a unique hash value for the model.
522
+
523
+ The hash value is created using the last 16 characters of the SHA256 hash encoding, based on the model configuration,
524
+ version, and model type.
454
525
455
526
Returns
456
527
-------
457
- Returns string of length 16 characters contains unique hash of the model
458
- """
528
+ str
529
+ A string of length 16 characters containing a unique hash of the model.
459
530
531
+ Examples
532
+ --------
533
+ >>> model = MyModel()
534
+ >>> model.id
535
+ '0123456789abcdef'
536
+ """
460
537
hasher = hashlib .sha256 ()
461
538
hasher .update (str (self .model_config .values ()).encode ())
462
539
hasher .update (self .version .encode ())
0 commit comments