Skip to content

Commit 812e60e

Browse files
Reset tuned Metropolis parameters in sequential sampling of chains (#3796)
* add regression test for #3733 * implement reset_tuning on Metropolis closes #3733 * mention fix of #3733
1 parent bb574a7 commit 812e60e

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
### Maintenance
1212
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
13+
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
1314

1415
## PyMC3 3.8 (November 29 2019)
1516

pymc3/step_methods/metropolis.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,25 @@ def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
146146
self.any_discrete = self.discrete.any()
147147
self.all_discrete = self.discrete.all()
148148

149+
# remember initial settings before tuning so they can be reset
150+
self._untuned_settings = dict(
151+
scaling=self.scaling,
152+
steps_until_tune=tune_interval,
153+
accepted=self.accepted
154+
)
155+
149156
self.mode = mode
150157

151158
shared = pm.make_shared_replacements(vars, model)
152159
self.delta_logp = delta_logp(model.logpt, vars, shared)
153160
super().__init__(vars, shared)
154161

162+
def reset_tuning(self):
163+
"""Resets the tuned sampler parameters to their initial values."""
164+
for attr, initial_value in self._untuned_settings.items():
165+
setattr(self, attr, initial_value)
166+
return
167+
155168
def astep(self, q0):
156169
if not self.steps_until_tune and self.tune:
157170
# Tune scaling parameter

pymc3/tests/test_step.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,27 @@ def test_parallelized_chains_are_random(self):
779779
pass
780780

781781

782+
class TestMetropolis:
783+
def test_tuning_reset(self):
784+
"""Re-use of the step method instance with cores=1 must not leak tuning information between chains."""
785+
with Model() as pmodel:
786+
D = 3
787+
Normal('n', 0, 2, shape=(D,))
788+
trace = sample(
789+
tune=600,
790+
draws=500,
791+
step=Metropolis(tune=True, scaling=0.1),
792+
cores=1,
793+
chains=3,
794+
discard_tuned_samples=False
795+
)
796+
for c in range(trace.nchains):
797+
# check that the tuned settings changed and were reset
798+
assert trace.get_sampler_stats('scaling', chains=c)[0] == 0.1
799+
assert trace.get_sampler_stats('scaling', chains=c)[-1] != 0.1
800+
pass
801+
802+
782803
class TestDEMetropolisZ:
783804
def test_tuning_lambda_sequential(self):
784805
with Model() as pmodel:

0 commit comments

Comments
 (0)