@@ -1126,10 +1126,10 @@ def test_acceptance_rate_against_coarseness(self):
1126
1126
Normal ("x" , 5.0 , 1.0 )
1127
1127
1128
1128
with Model () as coarse_model_1 :
1129
- Normal ("x" , 5.5 , 1.5 )
1129
+ Normal ("x" , 6.0 , 2.0 )
1130
1130
1131
1131
with Model () as coarse_model_2 :
1132
- Normal ("x" , 6 .0 , 2 .0 )
1132
+ Normal ("x" , 20 .0 , 5 .0 )
1133
1133
1134
1134
possible_coarse_models = [coarse_model_0 ,
1135
1135
coarse_model_1 ,
@@ -1139,9 +1139,9 @@ def test_acceptance_rate_against_coarseness(self):
1139
1139
with Model ():
1140
1140
Normal ("x" , 5.0 , 1.0 )
1141
1141
for coarse_model in possible_coarse_models :
1142
- step = MLDA (coarse_models = [coarse_model ], subsampling_rates = 1 ,
1143
- tune = False )
1144
- trace = sample (chains = 1 , draws = 500 , tune = 0 , step = step )
1142
+ step = MLDA (coarse_models = [coarse_model ], subsampling_rates = 3 ,
1143
+ tune = True )
1144
+ trace = sample (chains = 1 , draws = 500 , tune = 100 , step = step )
1145
1145
acc .append (trace .get_sampler_stats ('accepted' ).mean ())
1146
1146
assert acc [0 ] > acc [1 ] > acc [2 ], "Acceptance rate is not " \
1147
1147
"strictly increasing when" \
@@ -1197,10 +1197,10 @@ def test_tuning_and_scaling_on(self):
1197
1197
assert trace .get_sampler_stats ('tune' , chains = 0 )[ts - 1 ]
1198
1198
assert not trace .get_sampler_stats ('tune' , chains = 0 )[ts ]
1199
1199
assert not trace .get_sampler_stats ('tune' , chains = 0 )[- 1 ]
1200
- assert trace .get_sampler_stats ('base_scaling_x ' , chains = 0 )[0 ] == 100.
1201
- assert trace .get_sampler_stats ('base_scaling_y_logodds__ ' , chains = 0 )[0 ] == 100.
1202
- assert trace .get_sampler_stats ('base_scaling_x ' , chains = 0 )[- 1 ] < 100.
1203
- assert trace .get_sampler_stats ('base_scaling_y_logodds__ ' , chains = 0 )[- 1 ] < 100.
1200
+ assert trace .get_sampler_stats ('base_scaling ' , chains = 0 )[ 0 ] [0 ] == 100.
1201
+ assert trace .get_sampler_stats ('base_scaling ' , chains = 0 )[0 ][ 1 ] == 100.
1202
+ assert trace .get_sampler_stats ('base_scaling ' , chains = 0 )[- 1 ][ 0 ] < 100.
1203
+ assert trace .get_sampler_stats ('base_scaling ' , chains = 0 )[- 1 ][ 1 ] < 100.
1204
1204
1205
1205
def test_tuning_and_scaling_off (self ):
1206
1206
"""Test that tuning is deactivated when sample()'s tune=0 and that
@@ -1239,17 +1239,19 @@ def test_tuning_and_scaling_off(self):
1239
1239
1240
1240
assert not trace_0 .get_sampler_stats ('tune' , chains = 0 )[0 ]
1241
1241
assert not trace_0 .get_sampler_stats ('tune' , chains = 0 )[- 1 ]
1242
- assert trace_0 .get_sampler_stats ('base_scaling_x' , chains = 0 )[0 ] == \
1243
- trace_0 .get_sampler_stats ('base_scaling_x' , chains = 0 )[- 1 ] == 100.
1242
+ assert trace_0 .get_sampler_stats ('base_scaling' , chains = 0 )[0 ][0 ] == \
1243
+ trace_0 .get_sampler_stats ('base_scaling' , chains = 0 )[- 1 ][0 ] == \
1244
+ trace_0 .get_sampler_stats ('base_scaling' , chains = 0 )[0 ][1 ] == \
1245
+ trace_0 .get_sampler_stats ('base_scaling' , chains = 0 )[- 1 ][1 ] == 100.
1244
1246
1245
1247
assert trace_1 .get_sampler_stats ('tune' , chains = 0 )[0 ]
1246
1248
assert trace_1 .get_sampler_stats ('tune' , chains = 0 )[ts_1 - 1 ]
1247
1249
assert not trace_1 .get_sampler_stats ('tune' , chains = 0 )[ts_1 ]
1248
1250
assert not trace_1 .get_sampler_stats ('tune' , chains = 0 )[- 1 ]
1249
- assert trace_1 .get_sampler_stats ('base_scaling_x ' , chains = 0 )[0 ] == 100.
1250
- assert trace_1 .get_sampler_stats ('base_scaling_y_logodds__ ' , chains = 0 )[0 ] == 100.
1251
- assert trace_1 .get_sampler_stats ('base_scaling_x ' , chains = 0 )[- 1 ] < 100.
1252
- assert trace_1 .get_sampler_stats ('base_scaling_y_logodds__ ' , chains = 0 )[- 1 ] < 100.
1251
+ assert trace_1 .get_sampler_stats ('base_scaling ' , chains = 0 )[ 0 ] [0 ] == 100.
1252
+ assert trace_1 .get_sampler_stats ('base_scaling ' , chains = 0 )[0 ][ 1 ] == 100.
1253
+ assert trace_1 .get_sampler_stats ('base_scaling ' , chains = 0 )[- 1 ][ 0 ] < 100.
1254
+ assert trace_1 .get_sampler_stats ('base_scaling ' , chains = 0 )[- 1 ][ 1 ] < 100.
1253
1255
1254
1256
def test_trace_length (self ):
1255
1257
"""Check if trace length is as expected."""
0 commit comments