Skip to content

Commit 831a894

Browse files
Model Builder refactoring (#131)
* new branch after failed rebasing * making model config optional, implementing requested changes * removed super().__init__() * removed build method, used direct calls to build_model instead * refactoring build_model, renaming to 'build' * renamed build_model to 'build'
1 parent f4962be commit 831a894

File tree

2 files changed

+111
-81
lines changed

2 files changed

+111
-81
lines changed

pymc_experimental/model_builder.py

+56-49
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import hashlib
1717
import json
18+
from abc import abstractmethod
1819
from pathlib import Path
1920
from typing import Dict, Union
2021

@@ -24,56 +25,51 @@
2425
import pymc as pm
2526

2627

27-
class ModelBuilder(pm.Model):
28+
class ModelBuilder:
2829
"""
2930
ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
3031
and help with deployment.
31-
32-
Extends the pymc.Model class.
3332
"""
3433

3534
_model_type = "BaseClass"
3635
version = "None"
3736

3837
def __init__(
3938
self,
40-
model_config: Dict,
41-
sampler_config: Dict,
42-
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
39+
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
40+
model_config: Dict = None,
41+
sampler_config: Dict = None,
4342
):
4443
"""
4544
Initializes model configuration and sampler configuration for the model
4645
4746
Parameters
4847
----------
49-
model_config : Dictionary
48+
model_config : Dictionary, optional
5049
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method.
51-
sampler_config : Dictionary
52-
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
53-
data : Dictionary
50+
data : Dictionary, required
5451
It is the data we need to train the model on.
52+
sampler_config : Dictionary, optional
53+
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.
5554
Examples
5655
--------
5756
>>> class LinearModel(ModelBuilder):
5857
>>> ...
5958
>>> model = LinearModel(data, model_config, sampler_config)
6059
"""
6160

62-
super().__init__()
61+
if sampler_config is None:
62+
sampler_config = {}
63+
if model_config is None:
64+
model_config = {}
6365
self.model_config = model_config # parameters for priors etc.
64-
self.sample_config = sampler_config # parameters for sampling
65-
self.idata = None # inference data object
66+
self.sampler_config = sampler_config # parameters for sampling
6667
self.data = data
67-
self.build()
68-
69-
def build(self):
70-
"""
71-
Builds the defined model.
72-
"""
73-
74-
with self:
75-
self.build_model(self.model_config, self.data)
68+
self.idata = (
69+
None # inference data object placeholder, idata is populated by fit() or load()
70+
)
7671

72+
@abstractmethod
7773
def _data_setter(
7874
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True
7975
):
@@ -100,8 +96,9 @@ def _data_setter(
10096

10197
raise NotImplementedError
10298

103-
@classmethod
104-
def create_sample_input(cls):
99+
@staticmethod
100+
@abstractmethod
101+
def create_sample_input():
105102
"""
106103
Needs to be implemented by the user in the inherited class.
107104
Returns examples for data, model_config, sampler_config.
@@ -135,7 +132,7 @@ def create_sample_input(cls):
135132

136133
raise NotImplementedError
137134

138-
def save(self, fname):
135+
def save(self, fname: str) -> None:
139136
"""
140137
Saves inference data of the model.
141138
@@ -159,8 +156,9 @@ def save(self, fname):
159156
self.idata.to_netcdf(file)
160157

161158
@classmethod
162-
def load(cls, fname):
159+
def load(cls, fname: str):
163160
"""
161+
Creates a ModelBuilder instance from a file.
164162
Loads inference data for the model.
165163
166164
Parameters
@@ -170,7 +168,7 @@ def load(cls, fname):
170168
171169
Returns
172170
-------
173-
Returns the inference data that is loaded from local system.
171+
Returns an instance of ModelBuilder.
174172
175173
Raises
176174
------
@@ -187,22 +185,25 @@ def load(cls, fname):
187185

188186
filepath = Path(str(fname))
189187
idata = az.from_netcdf(filepath)
190-
self = cls(
191-
json.loads(idata.attrs["model_config"]),
192-
json.loads(idata.attrs["sampler_config"]),
193-
idata.fit_data.to_dataframe(),
188+
model_builder = cls(
189+
model_config=json.loads(idata.attrs["model_config"]),
190+
sampler_config=json.loads(idata.attrs["sampler_config"]),
191+
data=idata.fit_data.to_dataframe(),
194192
)
195-
self.idata = idata
196-
if self.id != idata.attrs["id"]:
193+
model_builder.idata = idata
194+
model_builder.build()
195+
if model_builder.id != idata.attrs["id"]:
197196
raise ValueError(
198-
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
197+
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
199198
)
200199

201-
return self
200+
return model_builder
202201

203-
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
202+
def fit(
203+
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
204+
) -> az.InferenceData:
204205
"""
205-
As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
206+
Fit a model using the data passed as a parameter.
206207
Sets attrs to inference data of the model.
207208
208209
Parameter
@@ -223,37 +224,40 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
223224
Initializing NUTS using jitter+adapt_diag...
224225
"""
225226

227+
# If new data was provided, assign it to the model
226228
if data is not None:
227229
self.data = data
228-
self._data_setter(data)
229230

230-
if self.basic_RVs == []:
231-
self.build()
231+
self.build()
232+
self._data_setter(data)
232233

233-
with self:
234-
self.idata = pm.sample(**self.sample_config)
234+
with self.model:
235+
self.idata = pm.sample(**self.sampler_config)
235236
self.idata.extend(pm.sample_prior_predictive())
236237
self.idata.extend(pm.sample_posterior_predictive(self.idata))
237238

238239
self.idata.attrs["id"] = self.id
239240
self.idata.attrs["model_type"] = self._model_type
240241
self.idata.attrs["version"] = self.version
241-
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config)
242+
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
242243
self.idata.attrs["model_config"] = json.dumps(self.model_config)
243244
self.idata.add_groups(fit_data=self.data.to_xarray())
244245
return self.idata
245246

246247
def predict(
247248
self,
248249
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
249-
):
250+
extend_idata: bool = True,
251+
) -> dict:
250252
"""
251253
Uses model to predict on unseen data and return point prediction of all the samples
252254
253255
Parameters
254256
---------
255257
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
256258
It is the data we need to make prediction on using the model.
259+
extend_idata : Boolean determining whether the predictions should be added to the inference data object.
260+
Defaults to True.
257261
258262
Returns
259263
-------
@@ -275,7 +279,8 @@ def predict(
275279

276280
with self.model: # sample with new input data
277281
post_pred = pm.sample_posterior_predictive(self.idata)
278-
282+
if extend_idata:
283+
self.idata.extend(post_pred)
279284
# reshape output
280285
post_pred = self._extract_samples(post_pred)
281286
for key in post_pred:
@@ -286,16 +291,17 @@ def predict(
286291
def predict_posterior(
287292
self,
288293
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
289-
):
294+
extend_idata: bool = True,
295+
) -> Dict[str, np.array]:
290296
"""
291297
Uses model to predict samples on unseen data.
292298
293299
Parameters
294300
---------
295301
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
296302
It is the data we need to make prediction on using the model.
297-
point_estimate : bool
298-
Adds point like estimate used as mean passed as
303+
extend_idata : Boolean determining whether the predictions should be added to the inference data object.
304+
Defaults to True.
299305
300306
Returns
301307
-------
@@ -317,6 +323,8 @@ def predict_posterior(
317323

318324
with self.model: # sample with new input data
319325
post_pred = pm.sample_posterior_predictive(self.idata)
326+
if extend_idata:
327+
self.idata.extend(post_pred)
320328

321329
# reshape output
322330
post_pred = self._extract_samples(post_pred)
@@ -357,5 +365,4 @@ def id(self) -> str:
357365
hasher.update(str(self.model_config.values()).encode())
358366
hasher.update(self.version.encode())
359367
hasher.update(self._model_type.encode())
360-
# hasher.update(str(self.sample_config.values()).encode())
361368
return hasher.hexdigest()[:16]

pymc_experimental/tests/test_model_builder.py

+55-32
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import hashlib
1716
import sys
1817
import tempfile
@@ -29,26 +28,28 @@ class test_ModelBuilder(ModelBuilder):
2928
_model_type = "LinearModel"
3029
version = "0.1"
3130

32-
def build_model(self, model_config, data=None):
33-
if data is not None:
34-
x = pm.MutableData("x", data["input"].values)
35-
y_data = pm.MutableData("y_data", data["output"].values)
31+
def build(self):
32+
33+
with pm.Model() as self.model:
34+
if self.data is not None:
35+
x = pm.MutableData("x", self.data["input"].values)
36+
y_data = pm.MutableData("y_data", self.data["output"].values)
3637

37-
# prior parameters
38-
a_loc = model_config["a_loc"]
39-
a_scale = model_config["a_scale"]
40-
b_loc = model_config["b_loc"]
41-
b_scale = model_config["b_scale"]
42-
obs_error = model_config["obs_error"]
38+
# prior parameters
39+
a_loc = self.model_config["a_loc"]
40+
a_scale = self.model_config["a_scale"]
41+
b_loc = self.model_config["b_loc"]
42+
b_scale = self.model_config["b_scale"]
43+
obs_error = self.model_config["obs_error"]
4344

44-
# priors
45-
a = pm.Normal("a", a_loc, sigma=a_scale)
46-
b = pm.Normal("b", b_loc, sigma=b_scale)
47-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
45+
# priors
46+
a = pm.Normal("a", a_loc, sigma=a_scale)
47+
b = pm.Normal("b", b_loc, sigma=b_scale)
48+
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
4849

49-
# observed data
50-
if data is not None:
51-
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
50+
# observed data
51+
if self.data is not None:
52+
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
5253

5354
def _data_setter(self, data: pd.DataFrame):
5455
with self.model:
@@ -57,7 +58,7 @@ def _data_setter(self, data: pd.DataFrame):
5758
pm.set_data({"y_data": data["output"].values})
5859

5960
@classmethod
60-
def create_sample_input(cls):
61+
def create_sample_input(self):
6162
x = np.linspace(start=0, stop=1, num=100)
6263
y = 5 * x + 3
6364
y = y + np.random.normal(0, 1, len(x))
@@ -81,14 +82,36 @@ def create_sample_input(cls):
8182
return data, model_config, sampler_config
8283

8384
@staticmethod
84-
def initial_build_and_fit(check_idata=True):
85+
def initial_build_and_fit(check_idata=True) -> ModelBuilder:
8586
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
86-
model = test_ModelBuilder(model_config, sampler_config, data)
87-
model.fit()
87+
model_builder = test_ModelBuilder(
88+
model_config=model_config, sampler_config=sampler_config, data=data
89+
)
90+
model_builder.idata = model_builder.fit(data=data)
8891
if check_idata:
89-
assert model.idata is not None
90-
assert "posterior" in model.idata.groups()
91-
return model
92+
assert model_builder.idata is not None
93+
assert "posterior" in model_builder.idata.groups()
94+
return model_builder
95+
96+
97+
def test_empty_sampler_config():
98+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
99+
sampler_config = {}
100+
model_builder = test_ModelBuilder(
101+
model_config=model_config, sampler_config=sampler_config, data=data
102+
)
103+
model_builder.idata = model_builder.fit(data=data)
104+
assert model_builder.idata is not None
105+
assert "posterior" in model_builder.idata.groups()
106+
107+
108+
def test_empty_model_config():
109+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
110+
model_config = {}
111+
model_builder = test_ModelBuilder(
112+
model_config=model_config, sampler_config=sampler_config, data=data
113+
)
114+
assert model_builder.model_config == {}
92115

93116

94117
def test_fit():
@@ -105,16 +128,16 @@ def test_fit():
105128
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
106129
)
107130
def test_save_load():
108-
model = test_ModelBuilder.initial_build_and_fit(False)
131+
test_builder = test_ModelBuilder.initial_build_and_fit()
109132
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
110-
model.save(temp.name)
111-
model2 = test_ModelBuilder.load(temp.name)
112-
assert model.idata.groups() == model2.idata.groups()
133+
test_builder.save(temp.name)
134+
test_builder2 = test_ModelBuilder.load(temp.name)
135+
assert test_builder.idata.groups() == test_builder2.idata.groups()
113136

114137
x_pred = np.random.uniform(low=0, high=1, size=100)
115138
prediction_data = pd.DataFrame({"input": x_pred})
116-
pred1 = model.predict(prediction_data)
117-
pred2 = model2.predict(prediction_data)
139+
pred1 = test_builder.predict(prediction_data)
140+
pred2 = test_builder2.predict(prediction_data)
118141
assert pred1["y_model"].shape == pred2["y_model"].shape
119142
temp.close()
120143

@@ -163,7 +186,7 @@ def test_extract_samples():
163186

164187
def test_id():
165188
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
166-
model = test_ModelBuilder(model_config, sampler_config, data)
189+
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config, data=data)
167190

168191
expected_id = hashlib.sha256(
169192
str(model_config.values()).encode() + model.version.encode() + model._model_type.encode()

0 commit comments

Comments
 (0)