@@ -28,8 +28,6 @@ class ModelBuilder:
28
28
"""
29
29
ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
30
30
and help with deployment.
31
-
32
- Extends the pymc.Model class.
33
31
"""
34
32
35
33
_model_type = "BaseClass"
@@ -65,7 +63,7 @@ def __init__(
65
63
self .idata = None # inference data object
66
64
self .data = data
67
65
68
- def build (self ):
66
+ def build (self ) -> None :
69
67
"""
70
68
Builds the defined model.
71
69
"""
@@ -136,7 +134,7 @@ def create_sample_input():
136
134
137
135
raise NotImplementedError
138
136
139
- def save (self , fname ) :
137
+ def save (self , fname : str ) -> None :
140
138
"""
141
139
Saves inference data of the model.
142
140
@@ -160,8 +158,9 @@ def save(self, fname):
160
158
self .idata .to_netcdf (file )
161
159
162
160
@classmethod
163
- def load (cls , fname ):
161
+ def load (cls , fname : str ):
164
162
"""
163
+ Creates a ModelBuilder instance from a file,
165
164
Loads inference data for the model.
166
165
167
166
Parameters
@@ -171,7 +170,7 @@ def load(cls, fname):
171
170
172
171
Returns
173
172
-------
174
- Returns an instance of pm.Model, that is loaded from local data .
173
+ Returns an instance of ModelBuilder .
175
174
176
175
Raises
177
176
------
@@ -188,21 +187,23 @@ def load(cls, fname):
188
187
189
188
filepath = Path (str (fname ))
190
189
idata = az .from_netcdf (filepath )
191
- self = cls (
190
+ model_builder = cls (
192
191
dict (zip (idata .attrs ["model_config_keys" ], idata .attrs ["model_config_values" ])),
193
192
dict (zip (idata .attrs ["sample_config_keys" ], idata .attrs ["sample_config_values" ])),
194
193
idata .fit_data .to_dataframe (),
195
194
)
196
- self .idata = idata
197
- self .build ()
198
- if self .id != idata .attrs ["id" ]:
195
+ model_builder .idata = idata
196
+ model_builder .build ()
197
+ if model_builder .id != idata .attrs ["id" ]:
199
198
raise ValueError (
200
199
f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ self ._model_type } '"
201
200
)
202
201
203
- return self . model
202
+ return model_builder
204
203
205
- def fit (self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ):
204
+ def fit (
205
+ self , data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None
206
+ ) -> az .InferenceData :
206
207
"""
207
208
As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
208
209
Sets attrs to inference data of the model.
@@ -248,7 +249,7 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
248
249
def predict (
249
250
self ,
250
251
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
251
- ):
252
+ ) -> dict :
252
253
"""
253
254
Uses model to predict on unseen data and return point prediction of all the samples
254
255
@@ -288,7 +289,7 @@ def predict(
288
289
def predict_posterior (
289
290
self ,
290
291
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
291
- ):
292
+ ) -> Dict [ str , np . array ] :
292
293
"""
293
294
Uses model to predict samples on unseen data.
294
295
0 commit comments