17
17
import json
18
18
from abc import abstractmethod
19
19
from pathlib import Path
20
- from typing import Dict , Union
20
+ from typing import Any , Dict , Union
21
21
22
22
import arviz as az
23
23
import numpy as np
24
24
import pandas as pd
25
25
import pymc as pm
26
+ from pymc .util import RandomState
26
27
27
28
28
29
class ModelBuilder :
@@ -100,7 +101,7 @@ def _data_setter(
100
101
@abstractmethod
101
102
def create_sample_input ():
102
103
"""
103
- Needs to be implemented by the user in the inherited class.
104
+ Needs to be implemented by the user in the child class.
104
105
Returns examples for data, model_config, sampler_config.
105
106
This is useful for understanding the required
106
107
data structures for the user model.
@@ -114,12 +115,15 @@ def create_sample_input():
114
115
>>> data = pd.DataFrame({'input': x, 'output': y})
115
116
116
117
>>> model_config = {
117
- >>> 'a_loc': 7,
118
- >>> 'a_scale': 3,
119
- >>> 'b_loc': 5,
120
- >>> 'b_scale': 3,
121
- >>> 'obs_error': 2,
122
- >>> }
118
+ >>> 'a' : {
119
+ >>> 'loc': 7,
120
+ >>> 'scale' : 3
121
+ >>> },
122
+ >>> 'b' : {
123
+ >>> 'loc': 3,
124
+ >>> 'scale': 5
125
+ >>> }
126
+ >>> 'obs_error': 2
123
127
124
128
>>> sampler_config = {
125
129
>>> 'draws': 1_000,
@@ -132,6 +136,31 @@ def create_sample_input():
132
136
133
137
raise NotImplementedError
134
138
139
+ @abstractmethod
140
+ def build_model (
141
+ model_data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]],
142
+ model_config : Dict [str , Union [int , float , Dict ]],
143
+ ) -> None :
144
+ """
145
+ Needs to be implemented by the user in the child class.
146
+ Creates an instance of pm.Model based on provided model_data and model_config, and
147
+ attaches it to self.
148
+
149
+ Required Parameters
150
+ ----------
151
+ model_data - preformated data that is going to be used in the model.
152
+ For efficiency reasons it should contain only the necesary data columns, not entire available
153
+ dataset since it's going to be encoded into data used to recreate the model.
154
+ model_config - dictionary where keys are strings representing names of parameters of the model, values are
155
+ dictionaries of parameters needed for creating model parameters (see example in create_model_input)
156
+
157
+ Returns:
158
+ ----------
159
+ None
160
+
161
+ """
162
+ raise NotImplementedError
163
+
135
164
def save (self , fname : str ) -> None :
136
165
"""
137
166
Saves inference data of the model.
@@ -151,9 +180,11 @@ def save(self, fname: str) -> None:
151
180
>>> name = './mymodel.nc'
152
181
>>> model.save(name)
153
182
"""
154
-
155
- file = Path (str (fname ))
156
- self .idata .to_netcdf (file )
183
+ if self .idata is not None and "fit_data" in self .idata :
184
+ file = Path (str (fname ))
185
+ self .idata .to_netcdf (file )
186
+ else :
187
+ raise RuntimeError ("The model hasn't been fit yet, call .fit() first" )
157
188
158
189
@classmethod
159
190
def load (cls , fname : str ):
@@ -191,7 +222,7 @@ def load(cls, fname: str):
191
222
data = idata .fit_data .to_dataframe (),
192
223
)
193
224
model_builder .idata = idata
194
- model_builder .build ( )
225
+ model_builder .build_model ( model_builder . data , model_builder . model_config )
195
226
if model_builder .id != idata .attrs ["id" ]:
196
227
raise ValueError (
197
228
f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
@@ -200,7 +231,12 @@ def load(cls, fname: str):
200
231
return model_builder
201
232
202
233
def fit (
203
- self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None
234
+ self ,
235
+ progressbar : bool = True ,
236
+ random_seed : RandomState = None ,
237
+ data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
238
+ * args : Any ,
239
+ ** kwargs : Any ,
204
240
) -> az .InferenceData :
205
241
"""
206
242
Fit a model using the data passed as a parameter.
@@ -227,20 +263,22 @@ def fit(
227
263
# If a new data was provided, assign it to the model
228
264
if data is not None :
229
265
self .data = data
230
-
231
- self .build ()
232
- self ._data_setter (data )
233
-
266
+ self .model_data , model_config , sampler_config = self .create_sample_input (data = self .data )
267
+ if self .model_config is None :
268
+ self .model_config = model_config
269
+ if self .sampler_config is None :
270
+ self .sampler_config = sampler_config
271
+ self .build_model (self .model_data , self .model_config )
234
272
with self .model :
235
- self .idata = pm .sample (** self .sampler_config )
273
+ self .idata = pm .sample (** self .sampler_config , ** kwargs )
236
274
self .idata .extend (pm .sample_prior_predictive ())
237
275
self .idata .extend (pm .sample_posterior_predictive (self .idata ))
238
276
239
277
self .idata .attrs ["id" ] = self .id
240
278
self .idata .attrs ["model_type" ] = self ._model_type
241
279
self .idata .attrs ["version" ] = self .version
242
280
self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
243
- self .idata .attrs ["model_config" ] = json .dumps (self .model_config )
281
+ self .idata .attrs ["model_config" ] = json .dumps (self ._serializable_model_config )
244
282
self .idata .add_groups (fit_data = self .data .to_xarray ())
245
283
return self .idata
246
284
@@ -351,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st
351
389
352
390
return post_pred_dict
353
391
392
+ @property
393
+ @abstractmethod
394
+ def _serializable_model_config (self ) -> Dict [str , Union [int , float , Dict ]]:
395
+ """
396
+ Converts non-serializable values from model_config to their serializable reversable equivalent.
397
+ Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
398
+ so in order to save the model they need to be formatted.
399
+
400
+ Returns
401
+ -------
402
+ model_config: dict
403
+ """
404
+
354
405
@property
355
406
def id (self ) -> str :
356
407
"""
0 commit comments