Skip to content

Commit db022a5

Browse files
authored
Merge pull request #39 from alan-turing-institute/mlda_develop
Improve stats calculation
2 parents 9278ddc + 9d25cd2 commit db022a5

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

pymc3/step_methods/metropolis.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,8 @@ class MLDA(ArrayStepShared):
981981
stats_dtypes = [{
982982
'accept': np.float64,
983983
'accepted': np.bool,
984-
'tune': np.bool
984+
'tune': np.bool,
985+
'base_scaling': object
985986
}]
986987

987988
def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=None,
@@ -1102,13 +1103,6 @@ def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=Non
11021103
self.tune,
11031104
self.subsampling_rates[-1])
11041105

1105-
# Update stats data types dictionary given vars and base_blocked
1106-
if self.base_blocked or len(self.vars) == 1:
1107-
self.stats_dtypes[0]['base_scaling'] = np.float64
1108-
else:
1109-
for name in self.var_names:
1110-
self.stats_dtypes[0]['base_scaling_' + name] = np.float64
1111-
11121106
def astep(self, q0):
11131107
"""One MLDA step, given current sample q0"""
11141108
# Check if the tuning flag has been changed and if yes,
@@ -1137,11 +1131,15 @@ def astep(self, q0):
11371131
# do not calculate likelihood, just set accept to 0.0
11381132
if (q == q0).all():
11391133
accept = np.float(0.0)
1134+
skipped_logp = True
11401135
else:
11411136
accept = self.delta_logp(q, q0) + self.delta_logp_next(q0, q)
1137+
skipped_logp = False
11421138

11431139
# Accept/reject sample - next sample is stored in q_new
11441140
q_new, accepted = metrop_select(accept, q, q0)
1141+
if skipped_logp:
1142+
accepted = False
11451143

11461144
# Update acceptance counter
11471145
self.accepted += accepted
@@ -1155,12 +1153,15 @@ def astep(self, q0):
11551153
# Capture latest base chain scaling stats from next step method
11561154
self.base_scaling_stats = {}
11571155
if isinstance(self.next_step_method, CompoundStep):
1156+
scaling_list = []
11581157
for method in self.next_step_method.methods:
1159-
self.base_scaling_stats["base_scaling_" + method.vars[0].name] = method.scaling
1160-
elif isinstance(self.next_step_method, Metropolis):
1161-
self.base_scaling_stats["base_scaling"] = self.next_step_method.scaling
1158+
scaling_list.append(method.scaling)
1159+
self.base_scaling_stats = {"base_scaling": np.array(scaling_list)}
1160+
elif not isinstance(self.next_step_method, MLDA):
1161+
# next method is any block sampler
1162+
self.base_scaling_stats = {"base_scaling": np.array(self.next_step_method.scaling)}
11621163
else:
1163-
# next method is MLDA
1164+
# next method is MLDA - propagate dict from lower levels
11641165
self.base_scaling_stats = self.next_step_method.base_scaling_stats
11651166
stats = {**stats, **self.base_scaling_stats}
11661167

pymc3/tests/test_step.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,10 +1126,10 @@ def test_acceptance_rate_against_coarseness(self):
11261126
Normal("x", 5.0, 1.0)
11271127

11281128
with Model() as coarse_model_1:
1129-
Normal("x", 5.5, 1.5)
1129+
Normal("x", 6.0, 2.0)
11301130

11311131
with Model() as coarse_model_2:
1132-
Normal("x", 6.0, 2.0)
1132+
Normal("x", 20.0, 5.0)
11331133

11341134
possible_coarse_models = [coarse_model_0,
11351135
coarse_model_1,
@@ -1139,9 +1139,9 @@ def test_acceptance_rate_against_coarseness(self):
11391139
with Model():
11401140
Normal("x", 5.0, 1.0)
11411141
for coarse_model in possible_coarse_models:
1142-
step = MLDA(coarse_models=[coarse_model], subsampling_rates=1,
1143-
tune=False)
1144-
trace = sample(chains=1, draws=500, tune=0, step=step)
1142+
step = MLDA(coarse_models=[coarse_model], subsampling_rates=3,
1143+
tune=True)
1144+
trace = sample(chains=1, draws=500, tune=100, step=step)
11451145
acc.append(trace.get_sampler_stats('accepted').mean())
11461146
assert acc[0] > acc[1] > acc[2], "Acceptance rate is not " \
11471147
"strictly increasing when" \
@@ -1197,10 +1197,10 @@ def test_tuning_and_scaling_on(self):
11971197
assert trace.get_sampler_stats('tune', chains=0)[ts - 1]
11981198
assert not trace.get_sampler_stats('tune', chains=0)[ts]
11991199
assert not trace.get_sampler_stats('tune', chains=0)[-1]
1200-
assert trace.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
1201-
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
1202-
assert trace.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
1203-
assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
1200+
assert trace.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
1201+
assert trace.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
1202+
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
1203+
assert trace.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.
12041204

12051205
def test_tuning_and_scaling_off(self):
12061206
"""Test that tuning is deactivated when sample()'s tune=0 and that
@@ -1239,17 +1239,19 @@ def test_tuning_and_scaling_off(self):
12391239

12401240
assert not trace_0.get_sampler_stats('tune', chains=0)[0]
12411241
assert not trace_0.get_sampler_stats('tune', chains=0)[-1]
1242-
assert trace_0.get_sampler_stats('base_scaling_x', chains=0)[0] == \
1243-
trace_0.get_sampler_stats('base_scaling_x', chains=0)[-1] == 100.
1242+
assert trace_0.get_sampler_stats('base_scaling', chains=0)[0][0] == \
1243+
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][0] == \
1244+
trace_0.get_sampler_stats('base_scaling', chains=0)[0][1] == \
1245+
trace_0.get_sampler_stats('base_scaling', chains=0)[-1][1] == 100.
12441246

12451247
assert trace_1.get_sampler_stats('tune', chains=0)[0]
12461248
assert trace_1.get_sampler_stats('tune', chains=0)[ts_1 - 1]
12471249
assert not trace_1.get_sampler_stats('tune', chains=0)[ts_1]
12481250
assert not trace_1.get_sampler_stats('tune', chains=0)[-1]
1249-
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
1250-
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
1251-
assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
1252-
assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
1251+
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][0] == 100.
1252+
assert trace_1.get_sampler_stats('base_scaling', chains=0)[0][1] == 100.
1253+
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][0] < 100.
1254+
assert trace_1.get_sampler_stats('base_scaling', chains=0)[-1][1] < 100.
12531255

12541256
def test_trace_length(self):
12551257
"""Check if trace length is as expected."""

0 commit comments

Comments
 (0)