Skip to content

Commit e38da06

Browse files
Add build_model abstractmethod to ModelBuilder (#142)
* adaptations to integrate with mmm * adapted model_config and descriptions * fixed ModuleNotFoundError from build * small tweaks to make mmm tests work smoother * new test for save, fit allows custom configs * updating create_sample_input example
1 parent 5f1c2bb commit e38da06

File tree

2 files changed

+98
-37
lines changed

2 files changed

+98
-37
lines changed

pymc_experimental/model_builder.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import json
1818
from abc import abstractmethod
1919
from pathlib import Path
20-
from typing import Dict, Union
20+
from typing import Any, Dict, Union
2121

2222
import arviz as az
2323
import numpy as np
2424
import pandas as pd
2525
import pymc as pm
26+
from pymc.util import RandomState
2627

2728

2829
class ModelBuilder:
@@ -100,7 +101,7 @@ def _data_setter(
100101
@abstractmethod
101102
def create_sample_input():
102103
"""
103-
Needs to be implemented by the user in the inherited class.
104+
Needs to be implemented by the user in the child class.
104105
Returns examples for data, model_config, sampler_config.
105106
This is useful for understanding the required
106107
data structures for the user model.
@@ -114,12 +115,15 @@ def create_sample_input():
114115
>>> data = pd.DataFrame({'input': x, 'output': y})
115116
116117
>>> model_config = {
117-
>>> 'a_loc': 7,
118-
>>> 'a_scale': 3,
119-
>>> 'b_loc': 5,
120-
>>> 'b_scale': 3,
121-
>>> 'obs_error': 2,
122-
>>> }
118+
>>> 'a' : {
119+
>>> 'loc': 7,
120+
>>> 'scale' : 3
121+
>>> },
122+
>>> 'b' : {
123+
>>> 'loc': 3,
124+
>>> 'scale': 5
125+
>>> }
126+
>>> 'obs_error': 2
123127
124128
>>> sampler_config = {
125129
>>> 'draws': 1_000,
@@ -132,6 +136,31 @@ def create_sample_input():
132136

133137
raise NotImplementedError
134138

139+
@abstractmethod
140+
def build_model(
141+
model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
142+
model_config: Dict[str, Union[int, float, Dict]],
143+
) -> None:
144+
"""
145+
Needs to be implemented by the user in the child class.
146+
Creates an instance of pm.Model based on provided model_data and model_config, and
147+
attaches it to self.
148+
149+
Required Parameters
150+
----------
151+
model_data - preformated data that is going to be used in the model.
152+
For efficiency reasons it should contain only the necesary data columns, not entire available
153+
dataset since it's going to be encoded into data used to recreate the model.
154+
model_config - dictionary where keys are strings representing names of parameters of the model, values are
155+
dictionaries of parameters needed for creating model parameters (see example in create_model_input)
156+
157+
Returns:
158+
----------
159+
None
160+
161+
"""
162+
raise NotImplementedError
163+
135164
def save(self, fname: str) -> None:
136165
"""
137166
Saves inference data of the model.
@@ -151,9 +180,11 @@ def save(self, fname: str) -> None:
151180
>>> name = './mymodel.nc'
152181
>>> model.save(name)
153182
"""
154-
155-
file = Path(str(fname))
156-
self.idata.to_netcdf(file)
183+
if self.idata is not None and "fit_data" in self.idata:
184+
file = Path(str(fname))
185+
self.idata.to_netcdf(file)
186+
else:
187+
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
157188

158189
@classmethod
159190
def load(cls, fname: str):
@@ -191,7 +222,7 @@ def load(cls, fname: str):
191222
data=idata.fit_data.to_dataframe(),
192223
)
193224
model_builder.idata = idata
194-
model_builder.build()
225+
model_builder.build_model(model_builder.data, model_builder.model_config)
195226
if model_builder.id != idata.attrs["id"]:
196227
raise ValueError(
197228
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
@@ -200,7 +231,12 @@ def load(cls, fname: str):
200231
return model_builder
201232

202233
def fit(
203-
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
234+
self,
235+
progressbar: bool = True,
236+
random_seed: RandomState = None,
237+
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
238+
*args: Any,
239+
**kwargs: Any,
204240
) -> az.InferenceData:
205241
"""
206242
Fit a model using the data passed as a parameter.
@@ -227,20 +263,22 @@ def fit(
227263
# If a new data was provided, assign it to the model
228264
if data is not None:
229265
self.data = data
230-
231-
self.build()
232-
self._data_setter(data)
233-
266+
self.model_data, model_config, sampler_config = self.create_sample_input(data=self.data)
267+
if self.model_config is None:
268+
self.model_config = model_config
269+
if self.sampler_config is None:
270+
self.sampler_config = sampler_config
271+
self.build_model(self.model_data, self.model_config)
234272
with self.model:
235-
self.idata = pm.sample(**self.sampler_config)
273+
self.idata = pm.sample(**self.sampler_config, **kwargs)
236274
self.idata.extend(pm.sample_prior_predictive())
237275
self.idata.extend(pm.sample_posterior_predictive(self.idata))
238276

239277
self.idata.attrs["id"] = self.id
240278
self.idata.attrs["model_type"] = self._model_type
241279
self.idata.attrs["version"] = self.version
242280
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
243-
self.idata.attrs["model_config"] = json.dumps(self.model_config)
281+
self.idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
244282
self.idata.add_groups(fit_data=self.data.to_xarray())
245283
return self.idata
246284

@@ -351,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st
351389

352390
return post_pred_dict
353391

392+
@property
393+
@abstractmethod
394+
def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
395+
"""
396+
Converts non-serializable values from model_config to their serializable reversable equivalent.
397+
Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
398+
so in order to save the model they need to be formatted.
399+
400+
Returns
401+
-------
402+
model_config: dict
403+
"""
404+
354405
@property
355406
def id(self) -> str:
356407
"""

pymc_experimental/tests/test_model_builder.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,26 @@ class test_ModelBuilder(ModelBuilder):
2828
_model_type = "LinearModel"
2929
version = "0.1"
3030

31-
def build(self):
32-
31+
def build_model(self, model_data, model_config):
3332
with pm.Model() as self.model:
34-
if self.data is not None:
35-
x = pm.MutableData("x", self.data["input"].values)
36-
y_data = pm.MutableData("y_data", self.data["output"].values)
33+
if model_data is not None:
34+
x = pm.MutableData("x", model_data["input"].values)
35+
y_data = pm.MutableData("y_data", model_data["output"].values)
3736

3837
# prior parameters
39-
a_loc = self.model_config["a_loc"]
40-
a_scale = self.model_config["a_scale"]
41-
b_loc = self.model_config["b_loc"]
42-
b_scale = self.model_config["b_scale"]
43-
obs_error = self.model_config["obs_error"]
38+
a_loc = model_config["a"]["loc"]
39+
a_scale = model_config["a"]["scale"]
40+
b_loc = model_config["b"]["loc"]
41+
b_scale = model_config["b"]["scale"]
42+
obs_error = model_config["obs_error"]
4443

4544
# priors
4645
a = pm.Normal("a", a_loc, sigma=a_scale)
4746
b = pm.Normal("b", b_loc, sigma=b_scale)
4847
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
4948

5049
# observed data
51-
if self.data is not None:
50+
if model_data is not None:
5251
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
5352

5453
def _data_setter(self, data: pd.DataFrame):
@@ -57,18 +56,20 @@ def _data_setter(self, data: pd.DataFrame):
5756
if "output" in data.columns:
5857
pm.set_data({"y_data": data["output"].values})
5958

59+
@property
60+
def _serializable_model_config(self):
61+
return self.model_config
62+
6063
@classmethod
61-
def create_sample_input(self):
64+
def create_sample_input(self, data=None):
6265
x = np.linspace(start=0, stop=1, num=100)
6366
y = 5 * x + 3
6467
y = y + np.random.normal(0, 1, len(x))
6568
data = pd.DataFrame({"input": x, "output": y})
6669

6770
model_config = {
68-
"a_loc": 0,
69-
"a_scale": 10,
70-
"b_loc": 0,
71-
"b_scale": 10,
71+
"a": {"loc": 0, "scale": 10},
72+
"b": {"loc": 0, "scale": 10},
7273
"obs_error": 2,
7374
}
7475

@@ -94,7 +95,16 @@ def initial_build_and_fit(check_idata=True) -> ModelBuilder:
9495
return model_builder
9596

9697

97-
def test_empty_model_config():
98+
def test_save_without_fit_raises_runtime_error():
99+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
100+
model_builder = test_ModelBuilder(
101+
model_config=model_config, sampler_config=sampler_config, data=data
102+
)
103+
with pytest.raises(RuntimeError):
104+
model_builder.save("saved_model")
105+
106+
107+
def test_empty_sampler_config_fit():
98108
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
99109
sampler_config = {}
100110
model_builder = test_ModelBuilder(
@@ -105,7 +115,7 @@ def test_empty_model_config():
105115
assert "posterior" in model_builder.idata.groups()
106116

107117

108-
def test_empty_model_config():
118+
def test_empty_model_config_fit():
109119
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
110120
model_config = {}
111121
model_builder = test_ModelBuilder(

0 commit comments

Comments
 (0)