Skip to content

Commit 9c481aa

Browse files
Made data a required argument and sampler_config optional. Added extend_idata to the prediction functions and adapted the test file.
1 parent 63f6be5 commit 9c481aa

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

pymc_experimental/model_builder.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class ModelBuilder:
3737
def __init__(
3838
self,
3939
model_config: Dict,
40-
sampler_config: Dict,
41-
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
40+
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
41+
sampler_config: Dict = None,
4242
):
4343
"""
4444
Initializes model configuration and sampler configuration for the model
@@ -47,10 +47,10 @@ def __init__(
4747
----------
4848
model_config : Dictionary
4949
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method.
50-
sampler_config : Dictionary
51-
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
5250
data : Dictionary
5351
It is the data we need to train the model on.
52+
sampler_config : Dictionary
53+
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
5454
Examples
5555
--------
5656
>>> class LinearModel(ModelBuilder):
@@ -60,9 +60,11 @@ def __init__(
6060

6161
super().__init__()
6262
self.model_config = model_config # parameters for priors etc.
63-
self.sample_config = sampler_config # parameters for sampling
64-
self.idata = None # inference data object
63+
self.sampler_config = sampler_config # parameters for sampling
6564
self.data = data
65+
self.idata = (
66+
None # inference data object placeholder, idata is generated during build execution
67+
)
6668

6769
def build(self) -> None:
6870
"""
@@ -188,16 +190,20 @@ def load(cls, fname: str):
188190

189191
filepath = Path(str(fname))
190192
idata = az.from_netcdf(filepath)
193+
if "sampler_config" in idata.attrs:
194+
sampler_config = json.loads(idata.attrs["sampler_config"])
195+
else:
196+
sampler_config = None
191197
model_builder = cls(
192-
json.loads(idata.attrs["model_config"]),
193-
json.loads(idata.attrs["sampler_config"]),
194-
idata.fit_data.to_dataframe(),
198+
model_config=json.loads(idata.attrs["model_config"]),
199+
sampler_config=sampler_config,
200+
data=idata.fit_data.to_dataframe(),
195201
)
196202
model_builder.idata = idata
197203
model_builder.build()
198204
if model_builder.id != idata.attrs["id"]:
199205
raise ValueError(
200-
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
206+
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
201207
)
202208

203209
return model_builder
@@ -206,7 +212,7 @@ def fit(
206212
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
207213
) -> az.InferenceData:
208214
"""
209-
As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
215+
Fit a model using the data passed as a parameter.
210216
Sets attrs to inference data of the model.
211217
212218
Parameter
@@ -233,21 +239,26 @@ def fit(
233239
self._data_setter(data)
234240

235241
with self.model:
236-
self.idata = pm.sample(**self.sample_config)
242+
if self.sampler_config is not None:
243+
self.idata = pm.sample(**self.sampler_config)
244+
else:
245+
self.idata = pm.sample()
237246
self.idata.extend(pm.sample_prior_predictive())
238247
self.idata.extend(pm.sample_posterior_predictive(self.idata))
239248

240249
self.idata.attrs["id"] = self.id
241250
self.idata.attrs["model_type"] = self._model_type
242251
self.idata.attrs["version"] = self.version
243-
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config)
252+
if self.sampler_config is not None:
253+
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
244254
self.idata.attrs["model_config"] = json.dumps(self.model_config)
245255
self.idata.add_groups(fit_data=self.data.to_xarray())
246256
return self.idata
247257

248258
def predict(
249259
self,
250260
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
261+
extend_idata: bool = True,
251262
) -> dict:
252263
"""
253264
Uses model to predict on unseen data and return point prediction of all the samples
@@ -256,6 +267,8 @@ def predict(
256267
---------
257268
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
258269
It is the data we need to make prediction on using the model.
270+
extend_idata : Boolean determining whether the predictions should be added to inference data object.
271+
Defaults to True.
259272
260273
Returns
261274
-------
@@ -277,7 +290,8 @@ def predict(
277290

278291
with self.model: # sample with new input data
279292
post_pred = pm.sample_posterior_predictive(self.idata)
280-
293+
if extend_idata:
294+
self.idata.extend(post_pred)
281295
# reshape output
282296
post_pred = self._extract_samples(post_pred)
283297
for key in post_pred:
@@ -288,6 +302,7 @@ def predict(
288302
def predict_posterior(
289303
self,
290304
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
305+
extend_idata: bool = True,
291306
) -> Dict[str, np.array]:
292307
"""
293308
Uses model to predict samples on unseen data.
@@ -296,8 +311,8 @@ def predict_posterior(
296311
---------
297312
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
298313
It is the data we need to make prediction on using the model.
299-
point_estimate : bool
300-
Adds point like estimate used as mean passed as
314+
extend_idata : Boolean determining whether the predictions should be added to inference data object.
315+
Defaults to True.
301316
302317
Returns
303318
-------
@@ -319,6 +334,8 @@ def predict_posterior(
319334

320335
with self.model: # sample with new input data
321336
post_pred = pm.sample_posterior_predictive(self.idata)
337+
if extend_idata:
338+
self.idata.extend(post_pred)
322339

323340
# reshape output
324341
post_pred = self._extract_samples(post_pred)
@@ -359,5 +376,4 @@ def id(self) -> str:
359376
hasher.update(str(self.model_config.values()).encode())
360377
hasher.update(self.version.encode())
361378
hasher.update(self._model_type.encode())
362-
# hasher.update(str(self.sample_config.values()).encode())
363379
return hasher.hexdigest()[:16]

pymc_experimental/tests/test_model_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,17 @@ class test_ModelBuilder(ModelBuilder):
2828
_model_type = "LinearModel"
2929
version = "0.1"
3030

31-
def build_model(self, model_instance: ModelBuilder, model_config: dict, data: dict = None):
31+
def build_model(
32+
self,
33+
model_instance: ModelBuilder,
34+
model_config: dict,
35+
data: dict = None,
36+
sampler_config: dict = None,
37+
):
3238
model_instance.model_config = model_config
3339
model_instance.data = data
3440
self.model_config = model_config
41+
self.sampler_config = sampler_config
3542
self.data = data
3643

3744
with pm.Model() as model_instance.model:

0 commit comments

Comments
 (0)