Skip to content

Commit 064e3ae

Browse files
ModelBuilder no longer extends pm.Model, pm.Model now is an attribute
1 parent 2e4507d commit 064e3ae

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

pymc_experimental/model_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def build(self):
7070
Builds the defined model.
7171
"""
7272

73-
self.build_model(self.model_config, self.data)
73+
self.build_model(self, self.model_config, self.data)
7474

7575
@abstractmethod
7676
def _data_setter(
@@ -188,8 +188,7 @@ def load(cls, fname):
188188

189189
filepath = Path(str(fname))
190190
idata = az.from_netcdf(filepath)
191-
self = ModelBuilder(idata)
192-
self.model = cls(
191+
self = cls(
193192
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
194193
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
195194
idata.fit_data.to_dataframe(),

pymc_experimental/tests/test_model_builder.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import hashlib
16+
import sys
17+
import tempfile
1518

1619
import numpy as np
1720
import pandas as pd
1821
import pymc as pm
22+
import pytest
1923

2024
from pymc_experimental.model_builder import ModelBuilder
2125

@@ -24,12 +28,13 @@ class test_ModelBuilder(ModelBuilder):
2428
_model_type = "LinearModel"
2529
version = "0.1"
2630

27-
def build_model(self, model_config, data=None):
28-
31+
def build_model(self, model_instance, model_config, data=None):
32+
model_instance.model_config = model_config
33+
model_instance.data = data
2934
self.model_config = model_config
3035
self.data = data
3136

32-
with pm.Model() as self.model:
37+
with pm.Model() as model_instance.model:
3338
if data is not None:
3439
x = pm.MutableData("x", data["input"].values)
3540
y_data = pm.MutableData("y_data", data["output"].values)
@@ -83,12 +88,12 @@ def create_sample_input(self):
8388
@staticmethod
8489
def initial_build_and_fit(check_idata=True):
8590
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
86-
model = test_ModelBuilder(model_config, sampler_config, data)
87-
model.fit(data=data)
91+
model_builder = test_ModelBuilder(model_config, sampler_config, data)
92+
model_builder.idata = model_builder.fit(data=data)
8893
if check_idata:
89-
assert model.idata is not None
90-
assert "posterior" in model.idata.groups()
91-
return model
94+
assert model_builder.idata is not None
95+
assert "posterior" in model_builder.idata.groups()
96+
return model_builder
9297

9398

9499
def test_fit():
@@ -101,16 +106,16 @@ def test_fit():
101106
assert "y_model" in post_pred.keys()
102107

103108

104-
"""
105109
@pytest.mark.skipif(
106110
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
107111
)
108112
def test_save_load():
109-
test_builder = test_ModelBuilder.initial_build_and_fit(False)
113+
test_builder = test_ModelBuilder.initial_build_and_fit()
110114
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
111115
test_builder.save(temp.name)
112-
test_builder2 = test_ModelBuilder.load(temp.name)
113-
assert test_builder.model.idata.groups() == test_builder2.model.idata.groups()
116+
test_builder2 = test_ModelBuilder.initial_build_and_fit()
117+
test_builder2.model = test_ModelBuilder.load(temp.name)
118+
assert test_builder.idata.groups() == test_builder2.idata.groups()
114119

115120
x_pred = np.random.uniform(low=0, high=1, size=100)
116121
prediction_data = pd.DataFrame({"input": x_pred})
@@ -171,4 +176,3 @@ def test_id():
171176
).hexdigest()[:16]
172177

173178
assert model.id == expected_id
174-
"""

0 commit comments

Comments
 (0)