15
15
16
16
import hashlib
import json
from abc import abstractmethod
from pathlib import Path
from typing import Dict, Optional, Union

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
25
26
26
27
27
class ModelBuilder:
    """
    ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
    and help with deployment.
    """

    # Identifier for the model family; subclasses override with a descriptive name.
    _model_type = "BaseClass"
    # Version tag of the model implementation; subclasses override.
    version = "None"

    def __init__(
        self,
        data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
        model_config: Optional[Dict] = None,
        sampler_config: Optional[Dict] = None,
    ):
        """
        Initializes model configuration and sampler configuration for the model

        Parameters
        ----------
        data : Dictionary, required
            It is the data we need to train the model on.
        model_config : Dictionary, optional
            dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method.
        sampler_config : Dictionary, optional
            dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        >>>     ...
        >>> model = LinearModel(data, model_config, sampler_config)
        """
        # Default both configs to fresh empty dicts here rather than using a
        # mutable `{}` default argument, which would be shared across instances.
        if sampler_config is None:
            sampler_config = {}
        if model_config is None:
            model_config = {}
        self.model_config = model_config  # parameters for priors etc.
        self.sampler_config = sampler_config  # parameters for sampling
        self.data = data
        self.idata = (
            None  # inference data object placeholder, idata is generated during build execution
        )
76
71
72
+ @abstractmethod
77
73
def _data_setter (
78
74
self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]], x_only : bool = True
79
75
):
@@ -100,8 +96,9 @@ def _data_setter(
100
96
101
97
raise NotImplementedError
102
98
103
- @classmethod
104
- def create_sample_input (cls ):
99
+ @staticmethod
100
+ @abstractmethod
101
+ def create_sample_input ():
105
102
"""
106
103
Needs to be implemented by the user in the inherited class.
107
104
Returns examples for data, model_config, sampler_config.
@@ -135,7 +132,7 @@ def create_sample_input(cls):
135
132
136
133
raise NotImplementedError
137
134
138
- def save (self , fname ) :
135
+ def save (self , fname : str ) -> None :
139
136
"""
140
137
Saves inference data of the model.
141
138
@@ -159,8 +156,9 @@ def save(self, fname):
159
156
self .idata .to_netcdf (file )
160
157
161
158
@classmethod
162
- def load (cls , fname ):
159
+ def load (cls , fname : str ):
163
160
"""
161
+ Creates a ModelBuilder instance from a file,
164
162
Loads inference data for the model.
165
163
166
164
Parameters
@@ -170,7 +168,7 @@ def load(cls, fname):
170
168
171
169
Returns
172
170
-------
173
- Returns the inference data that is loaded from local system .
171
+ Returns an instance of ModelBuilder .
174
172
175
173
Raises
176
174
------
@@ -187,22 +185,25 @@ def load(cls, fname):
187
185
188
186
filepath = Path (str (fname ))
189
187
idata = az .from_netcdf (filepath )
190
- self = cls (
191
- json .loads (idata .attrs ["model_config" ]),
192
- json .loads (idata .attrs ["sampler_config" ]),
193
- idata .fit_data .to_dataframe (),
188
+ model_builder = cls (
189
+ model_config = json .loads (idata .attrs ["model_config" ]),
190
+ sampler_config = json .loads (idata .attrs ["sampler_config" ]),
191
+ data = idata .fit_data .to_dataframe (),
194
192
)
195
- self .idata = idata
196
- if self .id != idata .attrs ["id" ]:
193
+ model_builder .idata = idata
194
+ model_builder .build ()
195
+ if model_builder .id != idata .attrs ["id" ]:
197
196
raise ValueError (
198
- f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ self ._model_type } '"
197
+ f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
199
198
)
200
199
201
- return self
200
+ return model_builder
202
201
203
- def fit (self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ):
202
+ def fit (
203
+ self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None
204
+ ) -> az .InferenceData :
204
205
"""
205
- As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
206
+ Fit a model using the data passed as a parameter.
206
207
Sets attrs to inference data of the model.
207
208
208
209
Parameter
@@ -223,37 +224,40 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
223
224
Initializing NUTS using jitter+adapt_diag...
224
225
"""
225
226
227
+ # If a new data was provided, assign it to the model
226
228
if data is not None :
227
229
self .data = data
228
- self ._data_setter (data )
229
230
230
- if self .basic_RVs == []:
231
- self .build ( )
231
+ self .build ()
232
+ self ._data_setter ( data )
232
233
233
- with self :
234
- self .idata = pm .sample (** self .sample_config )
234
+ with self . model :
235
+ self .idata = pm .sample (** self .sampler_config )
235
236
self .idata .extend (pm .sample_prior_predictive ())
236
237
self .idata .extend (pm .sample_posterior_predictive (self .idata ))
237
238
238
239
self .idata .attrs ["id" ] = self .id
239
240
self .idata .attrs ["model_type" ] = self ._model_type
240
241
self .idata .attrs ["version" ] = self .version
241
- self .idata .attrs ["sampler_config" ] = json .dumps (self .sample_config )
242
+ self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
242
243
self .idata .attrs ["model_config" ] = json .dumps (self .model_config )
243
244
self .idata .add_groups (fit_data = self .data .to_xarray ())
244
245
return self .idata
245
246
246
247
def predict (
247
248
self ,
248
249
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
249
- ):
250
+ extend_idata : bool = True ,
251
+ ) -> dict :
250
252
"""
251
253
Uses model to predict on unseen data and return point prediction of all the samples
252
254
253
255
Parameters
254
256
---------
255
257
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
256
258
It is the data we need to make prediction on using the model.
259
+ extend_idata : Boolean determining whether the predictions should be added to inference data object.
260
+ Defaults to True.
257
261
258
262
Returns
259
263
-------
@@ -275,7 +279,8 @@ def predict(
275
279
276
280
with self .model : # sample with new input data
277
281
post_pred = pm .sample_posterior_predictive (self .idata )
278
-
282
+ if extend_idata :
283
+ self .idata .extend (post_pred )
279
284
# reshape output
280
285
post_pred = self ._extract_samples (post_pred )
281
286
for key in post_pred :
@@ -286,16 +291,17 @@ def predict(
286
291
def predict_posterior (
287
292
self ,
288
293
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
289
- ):
294
+ extend_idata : bool = True ,
295
+ ) -> Dict [str , np .array ]:
290
296
"""
291
297
Uses model to predict samples on unseen data.
292
298
293
299
Parameters
294
300
---------
295
301
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
296
302
It is the data we need to make prediction on using the model.
297
- point_estimate : bool
298
- Adds point like estimate used as mean passed as
303
+ extend_idata : Boolean determining whether the predictions should be added to inference data object.
304
+ Defaults to True.
299
305
300
306
Returns
301
307
-------
@@ -317,6 +323,8 @@ def predict_posterior(
317
323
318
324
with self .model : # sample with new input data
319
325
post_pred = pm .sample_posterior_predictive (self .idata )
326
+ if extend_idata :
327
+ self .idata .extend (post_pred )
320
328
321
329
# reshape output
322
330
post_pred = self ._extract_samples (post_pred )
@@ -357,5 +365,4 @@ def id(self) -> str:
357
365
hasher .update (str (self .model_config .values ()).encode ())
358
366
hasher .update (self .version .encode ())
359
367
hasher .update (self ._model_type .encode ())
360
- # hasher.update(str(self.sample_config.values()).encode())
361
368
return hasher .hexdigest ()[:16 ]
0 commit comments