@@ -213,17 +213,19 @@ def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
213
213
self .mode = mode
214
214
215
215
# flag to indicate this stepper was instantiated within an MLDA stepper
216
- # used to decide if the tuning parameters are reset when iter_sample () is called
216
+ # if not, tuning parameters are reset when _iter_sample () is called
217
217
self .is_mlda_base = kwargs .pop ("is_mlda_base" , False )
218
218
219
219
shared = pm .make_shared_replacements (vars , model )
220
220
self .delta_logp = delta_logp (model .logpt , vars , shared )
221
221
super ().__init__ (vars , shared )
222
222
223
223
def reset_tuning (self ):
224
- """Resets the tuned sampler parameters to their initial values."""
225
- for attr , initial_value in self ._untuned_settings .items ():
226
- setattr (self , attr , initial_value )
224
+ """Resets the tuned sampler parameters to their initial values.
225
+ Skipped if stepper is a bottom-level stepper in MLDA."""
226
+ if not self .is_mlda_base :
227
+ for attr , initial_value in self ._untuned_settings .items ():
228
+ setattr (self , attr , initial_value )
227
229
return
228
230
229
231
def astep (self , q0 ):
@@ -1106,12 +1108,12 @@ def astep(self, q0):
1106
1108
"""One MLDA step, given current sample q0"""
1107
1109
# Check if the tuning flag has been changed and if yes,
1108
1110
# change the proposal's tuning flag and reset self.accepted
1109
- # This is triggered by iter_sample while the highest-level MLDA step
1111
+ # This is triggered by _iter_sample while the highest-level MLDA step
1110
1112
# method is running. It then propagates to all levels.
1111
1113
if self .proposal_dist .tune != self .tune :
1112
1114
self .proposal_dist .tune = self .tune
1113
1115
# set tune in sub-methods of compound stepper explicitly because
1114
- # it is not set within sample() (only the CompoundStep's tune flag is)
1116
+ # it is not set within sample.py (only the CompoundStep's tune flag is)
1115
1117
if isinstance (self .next_step_method , CompoundStep ):
1116
1118
for method in self .next_step_method .methods :
1117
1119
method .tune = self .tune
0 commit comments