Skip to content

Commit e62f924

Browse files
adapted model_config and descriptions
1 parent 455de05 commit e62f924

File tree

2 files changed

+58
-27
lines changed

2 files changed

+58
-27
lines changed

pymc_experimental/model_builder.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _data_setter(
102102
@abstractmethod
103103
def create_sample_input():
104104
"""
105-
Needs to be implemented by the user in the inherited class.
105+
Needs to be implemented by the user in the child class.
106106
Returns examples for data, model_config, sampler_config.
107107
This is useful for understanding the required
108108
data structures for the user model.
@@ -116,12 +116,15 @@ def create_sample_input():
116116
>>> data = pd.DataFrame({'input': x, 'output': y})
117117
118118
>>> model_config = {
119-
>>> 'a_loc': 7,
120-
>>> 'a_scale': 3,
121-
>>> 'b_loc': 5,
122-
>>> 'b_scale': 3,
123-
>>> 'obs_error': 2,
124-
>>> }
119+
>>> 'a' : {
120+
>>> 'a_loc': 7,
121+
>>> 'a_scale' : 3
122+
>>> },
123+
>>> 'b' : {
124+
>>> 'b_loc': 3,
125+
>>> 'b_scale': 5
126+
>>> }
127+
>>> 'obs_error': 2
125128
126129
>>> sampler_config = {
127130
>>> 'draws': 1_000,
@@ -134,6 +137,31 @@ def create_sample_input():
134137

135138
raise NotImplementedError
136139

140+
@abstractmethod
141+
def build_model(
142+
model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
143+
model_config: Dict[str, Union[int, float, Dict]],
144+
) -> None:
145+
"""
146+
Needs to be implemented by the user in the child class.
147+
Creates an instance of pm.Model based on provided model_data and model_config, and
148+
attaches it to self.
149+
150+
Required Parameters
151+
----------
152+
model_data - preformated data that is going to be used in the model.
153+
For efficiency reasons it should contain only the necesary data columns, not entire available
154+
dataset since it's going to be encoded into data used to recreate the model.
155+
model_config - dictionary where keys are strings representing names of parameters of the model, values are
156+
dictionaries of parameters needed for creating model parameters (see example in create_model_input)
157+
158+
Returns:
159+
----------
160+
None
161+
162+
"""
163+
raise NotImplementedError
164+
137165
def save(self, fname: str) -> None:
138166
"""
139167
Saves inference data of the model.
@@ -193,7 +221,7 @@ def load(cls, fname: str):
193221
data=idata.fit_data.to_dataframe(),
194222
)
195223
model_builder.idata = idata
196-
model_builder.build_model()
224+
model_builder.idata = model_builder.fit()
197225
if model_builder.id != idata.attrs["id"]:
198226
raise ValueError(
199227
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
@@ -234,9 +262,11 @@ def fit(
234262
# If a new data was provided, assign it to the model
235263
if data is not None:
236264
self.data = data
237-
self.model_data, self.model_config = self.create_sample_input(data=self.data)
265+
self.model_data, self.model_config, self.sampler_config = self.create_sample_input(
266+
data=self.data
267+
)
238268
self.build_model(self.model_data, self.model_config)
239-
self._data_setter(self.data)
269+
self._data_setter(self.model_data)
240270

241271
with self.model:
242272
self.idata = pm.sample(**self.sampler_config)
@@ -248,7 +278,7 @@ def fit(
248278
self.idata.attrs["version"] = self.version
249279
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
250280
self.idata.attrs["model_config"] = json.dumps(self.serializable_model_config)
251-
self.idata.add_groups(fit_data=self.data.to_xarray())
281+
self.idata.add_groups(fit_data=self.model_data.to_xarray())
252282
return self.idata
253283

254284
def predict(

pymc_experimental/tests/test_model_builder.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,26 @@ class test_ModelBuilder(ModelBuilder):
2828
_model_type = "LinearModel"
2929
version = "0.1"
3030

31-
def build(self):
32-
31+
def build_model(self, model_data, model_config):
3332
with pm.Model() as self.model:
34-
if self.data is not None:
35-
x = pm.MutableData("x", self.data["input"].values)
36-
y_data = pm.MutableData("y_data", self.data["output"].values)
33+
if model_data is not None:
34+
x = pm.MutableData("x", model_data["input"].values)
35+
y_data = pm.MutableData("y_data", model_data["output"].values)
3736

3837
# prior parameters
39-
a_loc = self.model_config["a_loc"]
40-
a_scale = self.model_config["a_scale"]
41-
b_loc = self.model_config["b_loc"]
42-
b_scale = self.model_config["b_scale"]
43-
obs_error = self.model_config["obs_error"]
38+
a_loc = model_config["a"]["loc"]
39+
a_scale = model_config["a"]["scale"]
40+
b_loc = model_config["b"]["loc"]
41+
b_scale = model_config["b"]["scale"]
42+
obs_error = model_config["obs_error"]
4443

4544
# priors
4645
a = pm.Normal("a", a_loc, sigma=a_scale)
4746
b = pm.Normal("b", b_loc, sigma=b_scale)
4847
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
4948

5049
# observed data
51-
if self.data is not None:
50+
if model_data is not None:
5251
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
5352

5453
def _data_setter(self, data: pd.DataFrame):
@@ -57,18 +56,20 @@ def _data_setter(self, data: pd.DataFrame):
5756
if "output" in data.columns:
5857
pm.set_data({"y_data": data["output"].values})
5958

59+
@property
60+
def serializable_model_config(self):
61+
return self.model_config
62+
6063
@classmethod
61-
def create_sample_input(self):
64+
def create_sample_input(self, data=None):
6265
x = np.linspace(start=0, stop=1, num=100)
6366
y = 5 * x + 3
6467
y = y + np.random.normal(0, 1, len(x))
6568
data = pd.DataFrame({"input": x, "output": y})
6669

6770
model_config = {
68-
"a_loc": 0,
69-
"a_scale": 10,
70-
"b_loc": 0,
71-
"b_scale": 10,
71+
"a": {"loc": 0, "scale": 10},
72+
"b": {"loc": 0, "scale": 10},
7273
"obs_error": 2,
7374
}
7475

0 commit comments

Comments
 (0)