Skip to content

Commit 2e4507d

Browse files
removing inheritance from pm.Model, adding instance of a model as a class property
1 parent 8cc3383 commit 2e4507d

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

pymc_experimental/model_builder.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import hashlib
17+
from abc import abstractmethod
1718
from pathlib import Path
1819
from typing import Dict, Union
1920

@@ -23,7 +24,7 @@
2324
import pymc as pm
2425

2526

26-
class ModelBuilder(pm.Model):
27+
class ModelBuilder:
2728
"""
2829
ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
2930
and help with deployment.
@@ -69,9 +70,9 @@ def build(self):
6970
Builds the defined model.
7071
"""
7172

72-
with self:
73-
self.build_model(self.model_config, self.data)
73+
self.build_model(self.model_config, self.data)
7474

75+
@abstractmethod
7576
def _data_setter(
7677
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True
7778
):
@@ -98,8 +99,10 @@ def _data_setter(
9899

99100
raise NotImplementedError
100101

101-
@classmethod
102-
def create_sample_input(cls):
102+
# need a discussion if it's really needed.
103+
@staticmethod
104+
@abstractmethod
105+
def create_sample_input():
103106
"""
104107
Needs to be implemented by the user in the inherited class.
105108
Returns examples for data, model_config, sampler_config.
@@ -168,7 +171,7 @@ def load(cls, fname):
168171
169172
Returns
170173
-------
171-
Returns the inference data that is loaded from local system.
174+
Returns an instance of pm.Model, that is loaded from local data.
172175
173176
Raises
174177
------
@@ -185,7 +188,8 @@ def load(cls, fname):
185188

186189
filepath = Path(str(fname))
187190
idata = az.from_netcdf(filepath)
188-
self = cls(
191+
self = ModelBuilder(idata)
192+
self.model = cls(
189193
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
190194
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
191195
idata.fit_data.to_dataframe(),
@@ -197,7 +201,7 @@ def load(cls, fname):
197201
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
198202
)
199203

200-
return self
204+
return self.model
201205

202206
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
203207
"""
@@ -227,7 +231,7 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
227231
self.build()
228232
self._data_setter(data)
229233

230-
with self:
234+
with self.model:
231235
self.idata = pm.sample(**self.sample_config)
232236
self.idata.extend(pm.sample_prior_predictive())
233237
self.idata.extend(pm.sample_posterior_predictive(self.idata))

pymc_experimental/tests/test_model_builder.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,9 @@
1313
# limitations under the License.
1414

1515

16-
import hashlib
17-
import sys
18-
import tempfile
19-
2016
import numpy as np
2117
import pandas as pd
2218
import pymc as pm
23-
import pytest
2419

2520
from pymc_experimental.model_builder import ModelBuilder
2621

@@ -30,25 +25,30 @@ class test_ModelBuilder(ModelBuilder):
3025
version = "0.1"
3126

3227
def build_model(self, model_config, data=None):
33-
if data is not None:
34-
x = pm.MutableData("x", data["input"].values)
35-
y_data = pm.MutableData("y_data", data["output"].values)
36-
37-
# prior parameters
38-
a_loc = model_config["a_loc"]
39-
a_scale = model_config["a_scale"]
40-
b_loc = model_config["b_loc"]
41-
b_scale = model_config["b_scale"]
42-
obs_error = model_config["obs_error"]
43-
44-
# priors
45-
a = pm.Normal("a", a_loc, sigma=a_scale)
46-
b = pm.Normal("b", b_loc, sigma=b_scale)
47-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
48-
49-
# observed data
50-
if data is not None:
51-
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
28+
29+
self.model_config = model_config
30+
self.data = data
31+
32+
with pm.Model() as self.model:
33+
if data is not None:
34+
x = pm.MutableData("x", data["input"].values)
35+
y_data = pm.MutableData("y_data", data["output"].values)
36+
37+
# prior parameters
38+
a_loc = model_config["a_loc"]
39+
a_scale = model_config["a_scale"]
40+
b_loc = model_config["b_loc"]
41+
b_scale = model_config["b_scale"]
42+
obs_error = model_config["obs_error"]
43+
44+
# priors
45+
a = pm.Normal("a", a_loc, sigma=a_scale)
46+
b = pm.Normal("b", b_loc, sigma=b_scale)
47+
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
48+
49+
# observed data
50+
if data is not None:
51+
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)
5252

5353
def _data_setter(self, data: pd.DataFrame):
5454
with self.model:
@@ -57,7 +57,7 @@ def _data_setter(self, data: pd.DataFrame):
5757
pm.set_data({"y_data": data["output"].values})
5858

5959
@classmethod
60-
def create_sample_input(cls):
60+
def create_sample_input(self):
6161
x = np.linspace(start=0, stop=1, num=100)
6262
y = 5 * x + 3
6363
y = y + np.random.normal(0, 1, len(x))
@@ -101,20 +101,21 @@ def test_fit():
101101
assert "y_model" in post_pred.keys()
102102

103103

104+
"""
104105
@pytest.mark.skipif(
105106
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
106107
)
107108
def test_save_load():
108-
model = test_ModelBuilder.initial_build_and_fit(False)
109+
test_builder = test_ModelBuilder.initial_build_and_fit(False)
109110
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
110-
model.save(temp.name)
111-
model2 = test_ModelBuilder.load(temp.name)
112-
assert model.idata.groups() == model2.idata.groups()
111+
test_builder.save(temp.name)
112+
test_builder2 = test_ModelBuilder.load(temp.name)
113+
assert test_builder.model.idata.groups() == test_builder2.model.idata.groups()
113114
114115
x_pred = np.random.uniform(low=0, high=1, size=100)
115116
prediction_data = pd.DataFrame({"input": x_pred})
116-
pred1 = model.predict(prediction_data)
117-
pred2 = model2.predict(prediction_data)
117+
pred1 = test_builder.predict(prediction_data)
118+
pred2 = test_builder2.predict(prediction_data)
118119
assert pred1["y_model"].shape == pred2["y_model"].shape
119120
temp.close()
120121
@@ -170,3 +171,4 @@ def test_id():
170171
).hexdigest()[:16]
171172
172173
assert model.id == expected_id
174+
"""

0 commit comments

Comments
 (0)