13
13
# limitations under the License.
14
14
15
15
16
+ import hashlib
16
17
import sys
17
18
import tempfile
18
19
@@ -79,14 +80,19 @@ def create_sample_input(cls):
79
80
80
81
return data , model_config , sampler_config
81
82
83
+ @staticmethod
84
+ def initial_build_and_fit (check_idata = True ):
85
+ data , model_config , sampler_config = test_ModelBuilder .create_sample_input ()
86
+ model = test_ModelBuilder (model_config , sampler_config , data )
87
+ model .fit ()
88
+ if check_idata :
89
+ assert model .idata is not None
90
+ assert "posterior" in model .idata .groups ()
91
+ return model
82
92
83
- def test_fit ():
84
- data , model_config , sampler_config = test_ModelBuilder .create_sample_input ()
85
- model = test_ModelBuilder (model_config , sampler_config , data )
86
- model .fit ()
87
- assert model .idata is not None
88
- assert "posterior" in model .idata .groups ()
89
93
94
+ def test_fit ():
95
+ model = test_ModelBuilder .initial_build_and_fit ()
90
96
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
91
97
prediction_data = pd .DataFrame ({"input" : x_pred })
92
98
pred = model .predict (prediction_data )
@@ -99,9 +105,7 @@ def test_fit():
99
105
sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
100
106
)
101
107
def test_save_load ():
102
- data , model_config , sampler_config = test_ModelBuilder .create_sample_input ()
103
- model = test_ModelBuilder (model_config , sampler_config , data )
104
- model .fit ()
108
+ model = test_ModelBuilder .initial_build_and_fit (False )
105
109
temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
106
110
model .save (temp .name )
107
111
model2 = test_ModelBuilder .load (temp .name )
@@ -113,3 +117,56 @@ def test_save_load():
113
117
pred2 = model2 .predict (prediction_data )
114
118
assert pred1 ["y_model" ].shape == pred2 ["y_model" ].shape
115
119
temp .close ()
120
+
121
+
122
+ def test_predict ():
123
+ model = test_ModelBuilder .initial_build_and_fit ()
124
+ x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
125
+ prediction_data = pd .DataFrame ({"input" : x_pred })
126
+ pred = model .predict (prediction_data )
127
+ assert "y_model" in pred
128
+ assert isinstance (pred , dict )
129
+ assert len (prediction_data .input .values ) == len (pred ["y_model" ])
130
+ assert isinstance (pred ["y_model" ][0 ], float )
131
+
132
+
133
+ def test_predict_posterior ():
134
+ model = test_ModelBuilder .initial_build_and_fit ()
135
+ x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
136
+ prediction_data = pd .DataFrame ({"input" : x_pred })
137
+ pred = model .predict_posterior (prediction_data )
138
+ assert "y_model" in pred
139
+ assert isinstance (pred , dict )
140
+ assert len (prediction_data .input .values ) == len (pred ["y_model" ][0 ])
141
+ assert isinstance (pred ["y_model" ][0 ], np .ndarray )
142
+
143
+
144
+ def test_extract_samples ():
145
+ # create a fake InferenceData object
146
+ with pm .Model () as model :
147
+ x = pm .Normal ("x" , mu = 0 , sigma = 1 )
148
+ intercept = pm .Normal ("intercept" , mu = 0 , sigma = 1 )
149
+ y_model = pm .Normal ("y_model" , mu = x * intercept , sigma = 1 , observed = [0 , 1 , 2 ])
150
+
151
+ idata = pm .sample (1000 , tune = 1000 )
152
+ post_pred = pm .sample_posterior_predictive (idata )
153
+
154
+ # call the function and get the output
155
+ samples_dict = test_ModelBuilder ._extract_samples (post_pred )
156
+
157
+ # assert that the keys and values are correct
158
+ assert len (samples_dict ) == len (post_pred .posterior_predictive )
159
+ for key in post_pred .posterior_predictive :
160
+ expected_value = post_pred .posterior_predictive [key ].to_numpy ()[0 ]
161
+ assert np .array_equal (samples_dict [key ], expected_value )
162
+
163
+
164
+ def test_id ():
165
+ data , model_config , sampler_config = test_ModelBuilder .create_sample_input ()
166
+ model = test_ModelBuilder (model_config , sampler_config , data )
167
+
168
+ expected_id = hashlib .sha256 (
169
+ str (model_config .values ()).encode () + model .version .encode () + model ._model_type .encode ()
170
+ ).hexdigest ()[:16 ]
171
+
172
+ assert model .id == expected_id
0 commit comments