6
6
import pytest
7
7
8
8
9
- STATS1 = [{
10
- 'a' : np .float64 ,
11
- 'b' : np .bool
12
- }]
9
+ STATS1 = [{"a" : np .float64 , "b" : np .bool }]
13
10
14
- STATS2 = [{
15
- 'a' : np .float64
16
- }, {
17
- 'a' : np .float64 ,
18
- 'b' : np .int64 ,
19
- }]
11
+ STATS2 = [{"a" : np .float64 }, {"a" : np .float64 , "b" : np .int64 ,}]
20
12
21
13
22
14
class TestNDArray0dSampling (bf .SamplingTestCase ):
@@ -152,7 +144,7 @@ class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
152
144
def test_add_values (self ):
153
145
mtrace = self .mtrace
154
146
orig_varnames = list (mtrace .varnames )
155
- name = ' new_var'
147
+ name = " new_var"
156
148
vals = mtrace [orig_varnames [0 ]]
157
149
mtrace .add_values ({name : vals })
158
150
assert len (orig_varnames ) == len (mtrace .varnames ) - 1
@@ -164,7 +156,6 @@ def test_add_values(self):
164
156
165
157
166
158
class TestSqueezeCat :
167
-
168
159
def setup_method (self ):
169
160
self .x = np .arange (10 )
170
161
self .y = np .arange (10 , 20 )
@@ -194,13 +185,14 @@ def test_combine_true_squeeze_true(self):
194
185
result = base ._squeeze_cat ([self .x , self .y ], True , True )
195
186
npt .assert_equal (result , expected )
196
187
188
+
197
189
class TestSaveLoad :
198
190
@staticmethod
199
191
def model ():
200
192
with pm .Model () as model :
201
- x = pm .Normal ('x' , 0 , 1 )
202
- y = pm .Normal ('y' , x , 1 , observed = 2 )
203
- z = pm .Normal ('z' , x + y , 1 )
193
+ x = pm .Normal ("x" , 0 , 1 )
194
+ y = pm .Normal ("y" , x , 1 , observed = 2 )
195
+ z = pm .Normal ("z" , x + y , 1 )
204
196
return model
205
197
206
198
@classmethod
@@ -209,12 +201,12 @@ def setup_class(cls):
209
201
cls .trace = pm .sample ()
210
202
211
203
def test_save_new_model (self , tmpdir_factory ):
212
- directory = str (tmpdir_factory .mktemp (' data' ))
204
+ directory = str (tmpdir_factory .mktemp (" data" ))
213
205
save_dir = pm .save_trace (self .trace , directory , overwrite = True )
214
206
215
207
assert save_dir == directory
216
208
with pm .Model () as model :
217
- w = pm .Normal ('w' , 0 , 1 )
209
+ w = pm .Normal ("w" , 0 , 1 )
218
210
new_trace = pm .sample ()
219
211
220
212
with pytest .raises (OSError ):
@@ -224,26 +216,32 @@ def test_save_new_model(self, tmpdir_factory):
224
216
with model :
225
217
new_trace_copy = pm .load_trace (directory )
226
218
227
- assert (new_trace ['w' ] == new_trace_copy ['w' ]).all ()
219
+ assert (new_trace ["w" ] == new_trace_copy ["w" ]).all ()
228
220
229
221
def test_save_and_load (self , tmpdir_factory ):
230
- directory = str (tmpdir_factory .mktemp (' data' ))
222
+ directory = str (tmpdir_factory .mktemp (" data" ))
231
223
save_dir = pm .save_trace (self .trace , directory , overwrite = True )
232
224
233
225
assert save_dir == directory
234
226
235
227
trace2 = pm .load_trace (directory , model = TestSaveLoad .model ())
236
228
237
- for var in ('x' , 'z' ):
229
+ for var in ("x" , "z" ):
238
230
assert (self .trace [var ] == trace2 [var ]).all ()
239
231
232
+ assert self .trace .stat_names == trace2 .stat_names
233
+ for stat in self .trace .stat_names :
234
+ assert all (self .trace [stat ] == trace2 [stat ]), (
235
+ "Restored value of statistic %s does not match stored value" % stat
236
+ )
237
+
240
238
def test_bad_load (self , tmpdir_factory ):
241
- directory = str (tmpdir_factory .mktemp (' data' ))
239
+ directory = str (tmpdir_factory .mktemp (" data" ))
242
240
with pytest .raises (pm .TraceDirectoryError ):
243
241
pm .load_trace (directory , model = TestSaveLoad .model ())
244
242
245
243
def test_sample_posterior_predictive (self , tmpdir_factory ):
246
- directory = str (tmpdir_factory .mktemp (' data' ))
244
+ directory = str (tmpdir_factory .mktemp (" data" ))
247
245
save_dir = pm .save_trace (self .trace , directory , overwrite = True )
248
246
249
247
assert save_dir == directory
0 commit comments