@@ -148,32 +148,35 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
148
148
assert np .all (idata12 ["x" ] != idata22 ["x" ])
149
149
assert np .all (idata13 ["x" ] != idata23 ["x" ])
150
150
151
- def test_sample_init (self ):
151
+ @pytest .mark .parametrize (
152
+ "init" ,
153
+ (
154
+ "advi" ,
155
+ "advi_map" ,
156
+ "map" ,
157
+ "adapt_diag" ,
158
+ "jitter+adapt_diag" ,
159
+ "jitter+adapt_diag_grad" ,
160
+ "adapt_full" ,
161
+ "jitter+adapt_full" ,
162
+ ),
163
+ )
164
+ def test_sample_init (self , init ):
152
165
with self .model :
153
- for init in (
154
- "advi" ,
155
- "advi_map" ,
156
- "map" ,
157
- "adapt_diag" ,
158
- "jitter+adapt_diag" ,
159
- "jitter+adapt_diag_grad" ,
160
- "adapt_full" ,
161
- "jitter+adapt_full" ,
162
- ):
163
- kwargs = {
164
- "init" : init ,
165
- "tune" : 120 ,
166
- "n_init" : 1000 ,
167
- "draws" : 50 ,
168
- "random_seed" : 20160911 ,
169
- }
170
- with warnings .catch_warnings (record = True ) as rec :
171
- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
172
- if init .endswith ("adapt_full" ):
173
- with pytest .warns (UserWarning , match = "experimental feature" ):
174
- pm .sample (** kwargs )
175
- else :
176
- pm .sample (** kwargs )
166
+ kwargs = {
167
+ "init" : init ,
168
+ "tune" : 120 ,
169
+ "n_init" : 1000 ,
170
+ "draws" : 50 ,
171
+ "random_seed" : 20160911 ,
172
+ }
173
+ with warnings .catch_warnings (record = True ) as rec :
174
+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
175
+ if init .endswith ("adapt_full" ):
176
+ with pytest .warns (UserWarning , match = "experimental feature" ):
177
+ pm .sample (** kwargs , cores = 1 )
178
+ else :
179
+ pm .sample (** kwargs , cores = 1 )
177
180
178
181
def test_sample_args (self ):
179
182
with self .model :
0 commit comments