13
13
# limitations under the License.
14
14
15
15
import hashlib
16
+ import json
16
17
import sys
17
18
import tempfile
18
19
from typing import Dict
@@ -43,29 +44,35 @@ def toy_y(toy_X):
43
44
@pytest .fixture (scope = "module" )
44
45
def fitted_model_instance (toy_X , toy_y ):
45
46
sampler_config = {
46
- "draws" : 500 ,
47
- "tune" : 300 ,
47
+ "draws" : 100 ,
48
+ "tune" : 100 ,
48
49
"chains" : 2 ,
49
50
"target_accept" : 0.95 ,
50
51
}
51
52
model_config = {
52
- "a" : {"loc" : 0 , "scale" : 10 },
53
+ "a" : {"loc" : 0 , "scale" : 10 , "dims" : ( "numbers" ,) },
53
54
"b" : {"loc" : 0 , "scale" : 10 },
54
55
"obs_error" : 2 ,
55
56
}
56
- model = test_ModelBuilder (model_config = model_config , sampler_config = sampler_config )
57
+ model = test_ModelBuilder (
58
+ model_config = model_config , sampler_config = sampler_config , test_parameter = "test_paramter"
59
+ )
57
60
model .fit (toy_X )
58
61
return model
59
62
60
63
61
64
class test_ModelBuilder (ModelBuilder ):
65
+ def __init__ (self , model_config = None , sampler_config = None , test_parameter = None ):
66
+ self .test_parameter = test_parameter
67
+ super ().__init__ (model_config = model_config , sampler_config = sampler_config )
62
68
63
- _model_type = "LinearModel "
69
+ _model_type = "test_model "
64
70
version = "0.1"
65
71
66
72
def build_model (self , X : pd .DataFrame , y : pd .Series , model_config = None ):
73
+ coords = {"numbers" : np .arange (len (X ))}
67
74
self .generate_and_preprocess_model_data (X , y )
68
- with pm .Model () as self .model :
75
+ with pm .Model (coords = coords ) as self .model :
69
76
if model_config is None :
70
77
model_config = self .default_model_config
71
78
x = pm .MutableData ("x" , self .X ["input" ].values )
@@ -79,13 +86,16 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
79
86
obs_error = model_config ["obs_error" ]
80
87
81
88
# priors
82
- a = pm .Normal ("a" , a_loc , sigma = a_scale )
89
+ a = pm .Normal ("a" , a_loc , sigma = a_scale , dims = model_config [ "a" ][ "dims" ] )
83
90
b = pm .Normal ("b" , b_loc , sigma = b_scale )
84
91
obs_error = pm .HalfNormal ("σ_model_fmc" , obs_error )
85
92
86
93
# observed data
87
94
output = pm .Normal ("output" , a + b * x , obs_error , shape = x .shape , observed = y_data )
88
95
96
+ def _save_input_params (self , idata ):
97
+ idata .attrs ["test_paramter" ] = json .dumps (self .test_parameter )
98
+
89
99
@property
90
100
def output_var (self ):
91
101
return "output"
@@ -107,7 +117,7 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
107
117
@property
108
118
def default_model_config (self ) -> Dict :
109
119
return {
110
- "a" : {"loc" : 0 , "scale" : 10 },
120
+ "a" : {"loc" : 0 , "scale" : 10 , "dims" : ( "numbers" ,) },
111
121
"b" : {"loc" : 0 , "scale" : 10 },
112
122
"obs_error" : 2 ,
113
123
}
@@ -122,6 +132,38 @@ def default_sampler_config(self) -> Dict:
122
132
}
123
133
124
134
135
+ def test_save_input_params (fitted_model_instance ):
136
+ assert fitted_model_instance .idata .attrs ["test_paramter" ] == '"test_paramter"'
137
+
138
+
139
+ def test_save_load (fitted_model_instance ):
140
+ temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
141
+ fitted_model_instance .save (temp .name )
142
+ test_builder2 = test_ModelBuilder .load (temp .name )
143
+ assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
144
+ assert fitted_model_instance .id == test_builder2 .id
145
+ x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
146
+ prediction_data = pd .DataFrame ({"input" : x_pred })
147
+ pred1 = fitted_model_instance .predict (prediction_data ["input" ])
148
+ pred2 = test_builder2 .predict (prediction_data ["input" ])
149
+ assert pred1 .shape == pred2 .shape
150
+ temp .close ()
151
+
152
+
153
+ def test_convert_dims_to_tuple (fitted_model_instance ):
154
+ model_config = {
155
+ "a" : {
156
+ "loc" : 0 ,
157
+ "scale" : 10 ,
158
+ "dims" : [
159
+ "x" ,
160
+ ],
161
+ },
162
+ }
163
+ converted_model_config = fitted_model_instance ._convert_dims_to_tuple (model_config )
164
+ assert converted_model_config ["a" ]["dims" ] == ("x" ,)
165
+
166
+
125
167
def test_initial_build_and_fit (fitted_model_instance , check_idata = True ) -> ModelBuilder :
126
168
if check_idata :
127
169
assert fitted_model_instance .idata is not None
@@ -162,20 +204,6 @@ def test_fit_no_y(toy_X):
162
204
@pytest .mark .skipif (
163
205
sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
164
206
)
165
- def test_save_load (fitted_model_instance ):
166
- temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
167
- fitted_model_instance .save (temp .name )
168
- test_builder2 = test_ModelBuilder .load (temp .name )
169
- assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
170
-
171
- x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
172
- prediction_data = pd .DataFrame ({"input" : x_pred })
173
- pred1 = fitted_model_instance .predict (prediction_data ["input" ])
174
- pred2 = test_builder2 .predict (prediction_data ["input" ])
175
- assert pred1 .shape == pred2 .shape
176
- temp .close ()
177
-
178
-
179
207
def test_predict (fitted_model_instance ):
180
208
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
181
209
prediction_data = pd .DataFrame ({"input" : x_pred })
0 commit comments