@@ -102,7 +102,7 @@ def _data_setter(
102
102
@abstractmethod
103
103
def create_sample_input ():
104
104
"""
105
- Needs to be implemented by the user in the inherited class.
105
+ Needs to be implemented by the user in the child class.
106
106
Returns examples for data, model_config, sampler_config.
107
107
This is useful for understanding the required
108
108
data structures for the user model.
@@ -116,12 +116,15 @@ def create_sample_input():
116
116
>>> data = pd.DataFrame({'input': x, 'output': y})
117
117
118
118
>>> model_config = {
119
- >>> 'a_loc': 7,
120
- >>> 'a_scale': 3,
121
- >>> 'b_loc': 5,
122
- >>> 'b_scale': 3,
123
- >>> 'obs_error': 2,
124
- >>> }
119
+ >>> 'a' : {
120
+ >>> 'a_loc': 7,
121
+ >>> 'a_scale' : 3
122
+ >>> },
123
+ >>> 'b' : {
124
+ >>> 'b_loc': 3,
125
+ >>> 'b_scale': 5
126
+ >>> },
127
+ >>> 'obs_error': 2
+ >>> }
125
128
126
129
>>> sampler_config = {
127
130
>>> 'draws': 1_000,
@@ -134,6 +137,31 @@ def create_sample_input():
134
137
135
138
raise NotImplementedError
136
139
140
+ @abstractmethod
141
+ def build_model (
142
+ model_data : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]],
143
+ model_config : Dict [str , Union [int , float , Dict ]],
144
+ ) -> None :
145
+ """
146
+ Needs to be implemented by the user in the child class.
147
+ Creates an instance of pm.Model based on provided model_data and model_config, and
148
+ attaches it to self.
149
+
150
+ Required Parameters
151
+ ----------
152
+ model_data - preformatted data that is going to be used in the model.
153
+ For efficiency reasons it should contain only the necessary data columns, not the entire available
154
+ dataset since it's going to be encoded into data used to recreate the model.
155
+ model_config - dictionary where keys are strings representing names of parameters of the model, values are
156
+ dictionaries of parameters needed for creating model parameters (see example in create_sample_input)
157
+
158
+ Returns:
159
+ ----------
160
+ None
161
+
162
+ """
163
+ raise NotImplementedError
164
+
137
165
def save (self , fname : str ) -> None :
138
166
"""
139
167
Saves inference data of the model.
@@ -193,7 +221,7 @@ def load(cls, fname: str):
193
221
data = idata .fit_data .to_dataframe (),
194
222
)
195
223
model_builder .idata = idata
196
- model_builder .build_model ()
224
+ model_builder .idata = model_builder . fit ()
197
225
if model_builder .id != idata .attrs ["id" ]:
198
226
raise ValueError (
199
227
f"The file '{ fname } ' does not contain an inference data of the same model or configuration as '{ cls ._model_type } '"
@@ -234,9 +262,11 @@ def fit(
234
262
# If a new data was provided, assign it to the model
235
263
if data is not None :
236
264
self .data = data
237
- self .model_data , self .model_config = self .create_sample_input (data = self .data )
265
+ self .model_data , self .model_config , self .sampler_config = self .create_sample_input (
266
+ data = self .data
267
+ )
238
268
self .build_model (self .model_data , self .model_config )
239
- self ._data_setter (self .data )
269
+ self ._data_setter (self .model_data )
240
270
241
271
with self .model :
242
272
self .idata = pm .sample (** self .sampler_config )
@@ -248,7 +278,7 @@ def fit(
248
278
self .idata .attrs ["version" ] = self .version
249
279
self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
250
280
self .idata .attrs ["model_config" ] = json .dumps (self .serializable_model_config )
251
- self .idata .add_groups (fit_data = self .data .to_xarray ())
281
+ self .idata .add_groups (fit_data = self .model_data .to_xarray ())
252
282
return self .idata
253
283
254
284
def predict (
0 commit comments