Skip to content

Commit 680dadd

Browse files
authored
Fixing HMC et al. tuning schedule to reset at the beginning of a chain even on 1 core (#3941)
* fixing #3939
* fixing introduced bug in fulladapt quadpot
1 parent b6a88f0 commit 680dadd

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def astep(self, q0):
198198

199199
return hmc_step.end.q, [stats]
200200

201+
def reset_tuning(self, start=None):
202+
self.step_adapt.reset()
203+
self.reset(start=None)
204+
201205
def reset(self, start=None):
202206
self.tune = True
203207
self.potential.reset()

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,24 @@ def __init__(
171171

172172
self.dtype = dtype
173173
self._n = n
174-
self._var = np.array(initial_diag, dtype=self.dtype, copy=True)
174+
175+
self._initial_mean = initial_mean
176+
self._initial_diag = initial_diag
177+
self._initial_weight = initial_weight
178+
self.adaptation_window = adaptation_window
179+
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
180+
181+
self.reset()
182+
183+
def reset(self):
184+
self._var = np.array(self._initial_diag, dtype=self.dtype, copy=True)
175185
self._var_theano = theano.shared(self._var)
176-
self._stds = np.sqrt(initial_diag)
186+
self._stds = np.sqrt(self._initial_diag)
177187
self._inv_stds = floatX(1.) / self._stds
178188
self._foreground_var = _WeightedVariance(
179-
self._n, initial_mean, initial_diag, initial_weight, self.dtype)
189+
self._n, self._initial_mean, self._initial_diag, self._initial_weight, self.dtype)
180190
self._background_var = _WeightedVariance(self._n, dtype=self.dtype)
181191
self._n_samples = 0
182-
self.adaptation_window = adaptation_window
183-
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
184192

185193
def velocity(self, x, out=None):
186194
"""Compute the current velocity at a position in parameter space."""
@@ -275,8 +283,8 @@ class QuadPotentialDiagAdaptGrad(QuadPotentialDiagAdapt):
275283
This is experimental, and may be removed without prior deprecation.
276284
"""
277285

278-
def __init__(self, *args, **kwargs):
279-
super().__init__(*args, **kwargs)
286+
def reset(self):
287+
super().reset()
280288
self._grads1 = np.zeros(self._n, dtype=self.dtype)
281289
self._ngrads1 = 0
282290
self._grads2 = np.zeros(self._n, dtype=self.dtype)
@@ -518,20 +526,27 @@ def __init__(
518526

519527
self.dtype = dtype
520528
self._n = n
521-
self._cov = np.array(initial_cov, dtype=self.dtype, copy=True)
529+
self._initial_mean = initial_mean
530+
self._initial_cov = initial_cov
531+
self._initial_weight = initial_weight
532+
533+
self.adaptation_window = int(adaptation_window)
534+
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
535+
self._update_window = int(update_window)
536+
537+
self.reset()
538+
539+
def reset(self):
540+
self._previous_update = 0
541+
self._cov = np.array(self._initial_cov, dtype=self.dtype, copy=True)
522542
self._chol = scipy.linalg.cholesky(self._cov, lower=True)
523543
self._chol_error = None
524544
self._foreground_cov = _WeightedCovariance(
525-
self._n, initial_mean, initial_cov, initial_weight, self.dtype
545+
self._n, self._initial_mean, self._initial_cov, self._initial_weight, self.dtype
526546
)
527547
self._background_cov = _WeightedCovariance(self._n, dtype=self.dtype)
528548
self._n_samples = 0
529549

530-
self.adaptation_window = int(adaptation_window)
531-
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
532-
self._update_window = int(update_window)
533-
self._previous_update = 0
534-
535550
def _update_from_weightvar(self, weightvar):
536551
weightvar.current_covariance(out=self._cov)
537552
try:

pymc3/step_methods/step_sizes.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,19 @@
2020

2121
class DualAverageAdaptation:
2222
def __init__(self, initial_step, target, gamma, k, t0):
23-
self._log_step = np.log(initial_step)
24-
self._log_bar = self._log_step
23+
self._initial_step = initial_step
2524
self._target = target
26-
self._hbar = 0.
2725
self._k = k
2826
self._t0 = t0
29-
self._count = 1
30-
self._mu = np.log(10 * initial_step)
3127
self._gamma = gamma
28+
self.reset()
29+
30+
def reset(self):
31+
self._log_step = np.log(self._initial_step)
32+
self._log_bar = self._log_step
33+
self._hbar = 0.
34+
self._count = 1
35+
self._mu = np.log(10 * self._initial_step)
3236
self._tuned_stats = []
3337

3438
def current(self, tune):

pymc3/tests/test_sampling.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ def test_sample_tune_len(self):
145145
trace = pm.sample(draws=100, tune=50, cores=4)
146146
assert len(trace) == 100
147147

148+
def test_reset_tuning(self):
149+
with self.model:
150+
tune = 50
151+
chains = 2
152+
start, step = pm.sampling.init_nuts(chains=chains)
153+
pm.sample(draws=2, tune=tune, chains=chains, step=step, start=start, cores=1)
154+
assert step.potential._n_samples == tune
155+
assert step.step_adapt._count == tune + 1
156+
148157
@pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
149158
@pytest.mark.parametrize("discard", [True, False])
150159
def test_trace_report(self, step_cls, discard):

0 commit comments

Comments
 (0)