12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import hashlib
16
+ import sys
17
+ import tempfile
15
18
16
19
import numpy as np
17
20
import pandas as pd
18
21
import pymc as pm
22
+ import pytest
19
23
20
24
from pymc_experimental .model_builder import ModelBuilder
21
25
@@ -24,12 +28,13 @@ class test_ModelBuilder(ModelBuilder):
24
28
_model_type = "LinearModel"
25
29
version = "0.1"
26
30
27
- def build_model (self , model_config , data = None ):
28
-
31
+ def build_model (self , model_instance , model_config , data = None ):
32
+ model_instance .model_config = model_config
33
+ model_instance .data = data
29
34
self .model_config = model_config
30
35
self .data = data
31
36
32
- with pm .Model () as self .model :
37
+ with pm .Model () as model_instance .model :
33
38
if data is not None :
34
39
x = pm .MutableData ("x" , data ["input" ].values )
35
40
y_data = pm .MutableData ("y_data" , data ["output" ].values )
@@ -83,12 +88,12 @@ def create_sample_input(self):
83
88
@staticmethod
84
89
def initial_build_and_fit (check_idata = True ):
85
90
data , model_config , sampler_config = test_ModelBuilder .create_sample_input ()
86
- model = test_ModelBuilder (model_config , sampler_config , data )
87
- model .fit (data = data )
91
+ model_builder = test_ModelBuilder (model_config , sampler_config , data )
92
+ model_builder . idata = model_builder .fit (data = data )
88
93
if check_idata :
89
- assert model .idata is not None
90
- assert "posterior" in model .idata .groups ()
91
- return model
94
+ assert model_builder .idata is not None
95
+ assert "posterior" in model_builder .idata .groups ()
96
+ return model_builder
92
97
93
98
94
99
def test_fit ():
@@ -101,16 +106,16 @@ def test_fit():
101
106
assert "y_model" in post_pred .keys ()
102
107
103
108
104
- """
105
109
@pytest .mark .skipif (
106
110
sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
107
111
)
108
112
def test_save_load ():
109
- test_builder = test_ModelBuilder.initial_build_and_fit(False )
113
+ test_builder = test_ModelBuilder .initial_build_and_fit ()
110
114
temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
111
115
test_builder .save (temp .name )
112
- test_builder2 = test_ModelBuilder.load(temp.name)
113
- assert test_builder.model.idata.groups() == test_builder2.model.idata.groups()
116
+ test_builder2 = test_ModelBuilder .initial_build_and_fit ()
117
+ test_builder2 .model = test_ModelBuilder .load (temp .name )
118
+ assert test_builder .idata .groups () == test_builder2 .idata .groups ()
114
119
115
120
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
116
121
prediction_data = pd .DataFrame ({"input" : x_pred })
@@ -171,4 +176,3 @@ def test_id():
171
176
).hexdigest ()[:16 ]
172
177
173
178
assert model .id == expected_id
174
- """
0 commit comments