Skip to content

Commit 94322ba

Browse files
committed
Move is_mlda_base checks for reseting inside Metropolis' reset_tuning()
1 parent 9f9c395 commit 94322ba

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

pymc3/sampling.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -903,18 +903,7 @@ def _iter_sample(
903903
try:
904904
step.tune = bool(tune)
905905
if hasattr(step, 'reset_tuning'):
906-
if isinstance(step, CompoundStep):
907-
if hasattr(step.methods[0], 'is_mlda_base'):
908-
if not step.methods[0].is_mlda_base:
909-
step.reset_tuning()
910-
else:
911-
step.reset_tuning()
912-
else:
913-
if hasattr(step, 'is_mlda_base'):
914-
if not step.is_mlda_base:
915-
step.reset_tuning()
916-
else:
917-
step.reset_tuning()
906+
step.reset_tuning()
918907
for i in range(draws):
919908
stats = None
920909
diverging = False

pymc3/step_methods/metropolis.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,19 @@ def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
213213
self.mode = mode
214214

215215
# flag to indicate this stepper was instantiated within an MLDA stepper
216-
# used to decide if the tuning parameters are reset when iter_sample() is called
216+
# if not, tuning parameters are reset when _iter_sample() is called
217217
self.is_mlda_base = kwargs.pop("is_mlda_base", False)
218218

219219
shared = pm.make_shared_replacements(vars, model)
220220
self.delta_logp = delta_logp(model.logpt, vars, shared)
221221
super().__init__(vars, shared)
222222

223223
def reset_tuning(self):
224-
"""Resets the tuned sampler parameters to their initial values."""
225-
for attr, initial_value in self._untuned_settings.items():
226-
setattr(self, attr, initial_value)
224+
"""Resets the tuned sampler parameters to their initial values.
225+
Skipped if stepper is a bottom-level stepper in MLDA."""
226+
if not self.is_mlda_base:
227+
for attr, initial_value in self._untuned_settings.items():
228+
setattr(self, attr, initial_value)
227229
return
228230

229231
def astep(self, q0):
@@ -1106,12 +1108,12 @@ def astep(self, q0):
11061108
"""One MLDA step, given current sample q0"""
11071109
# Check if the tuning flag has been changed and if yes,
11081110
# change the proposal's tuning flag and reset self.accepted
1109-
# This is triggered by iter_sample while the highest-level MLDA step
1111+
# This is triggered by _iter_sample while the highest-level MLDA step
11101112
# method is running. It then propagates to all levels.
11111113
if self.proposal_dist.tune != self.tune:
11121114
self.proposal_dist.tune = self.tune
11131115
# set tune in sub-methods of compound stepper explicitly because
1114-
# it is not set within sample() (only the CompoundStep's tune flag is)
1116+
# it is not set within sample.py (only the CompoundStep's tune flag is)
11151117
if isinstance(self.next_step_method, CompoundStep):
11161118
for method in self.next_step_method.methods:
11171119
method.tune = self.tune

0 commit comments

Comments
 (0)