Skip to content

Commit 84a7902

Browse files
authored
Add multiplier for adaptation window sizes (#3705)
* Add adaptation window multiplier to full mass matrix * Add adaptation window multiplier to diag adapt * Set default value of adaptatation window multiplier to be an int * Clean up extraneous variable * Fix tests with multiplier * Update docstring and remove extraneous parameter * Fix tests * Make multiplier float and cast to int after multiplying
1 parent 6c1824d commit 84a7902

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,16 @@ def isquadpotential(value):
126126
class QuadPotentialDiagAdapt(QuadPotential):
127127
"""Adapt a diagonal mass matrix from the sample variances."""
128128

129-
def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0,
130-
adaptation_window=101, dtype=None):
129+
def __init__(
130+
self,
131+
n,
132+
initial_mean,
133+
initial_diag=None,
134+
initial_weight=0,
135+
adaptation_window=101,
136+
adaptation_window_multiplier=1,
137+
dtype=None,
138+
):
131139
"""Set up a diagonal mass matrix."""
132140
if initial_diag is not None and initial_diag.ndim != 1:
133141
raise ValueError('Initial diagonal must be one-dimensional.')
@@ -158,6 +166,7 @@ def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0,
158166
self._background_var = _WeightedVariance(self._n, dtype=self.dtype)
159167
self._n_samples = 0
160168
self.adaptation_window = adaptation_window
169+
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
161170

162171
def velocity(self, x, out=None):
163172
"""Compute the current velocity at a position in parameter space."""
@@ -190,15 +199,14 @@ def update(self, sample, grad, tune):
190199
if not tune:
191200
return
192201

193-
window = self.adaptation_window
194-
195202
self._foreground_var.add_sample(sample, weight=1)
196203
self._background_var.add_sample(sample, weight=1)
197204
self._update_from_weightvar(self._foreground_var)
198205

199-
if self._n_samples > 0 and self._n_samples % window == 0:
206+
if self._n_samples > 0 and self._n_samples % self.adaptation_window == 0:
200207
self._foreground_var = self._background_var
201208
self._background_var = _WeightedVariance(self._n, dtype=self.dtype)
209+
self.adaptation_window = int(self.adaptation_window * self.adaptation_window_multiplier)
202210

203211
self._n_samples += 1
204212

@@ -458,22 +466,16 @@ def velocity_energy(self, x, v_out):
458466

459467

460468
class QuadPotentialFullAdapt(QuadPotentialFull):
461-
"""Adapt a dense mass matrix using the sample covariances
462-
463-
If the parameter ``doubling`` is true, the adaptation window is doubled
464-
every time it is passed. This can lead to better convergence of the mass
465-
matrix estimation.
466-
467-
"""
469+
"""Adapt a dense mass matrix using the sample covariances."""
468470
def __init__(
469471
self,
470472
n,
471473
initial_mean,
472474
initial_cov=None,
473475
initial_weight=0,
474476
adaptation_window=101,
477+
adaptation_window_multiplier=2,
475478
update_window=1,
476-
doubling=True,
477479
dtype=None,
478480
):
479481
warnings.warn("QuadPotentialFullAdapt is an experimental feature")
@@ -511,8 +513,8 @@ def __init__(
511513
self._background_cov = _WeightedCovariance(self._n, dtype=self.dtype)
512514
self._n_samples = 0
513515

514-
self._doubling = doubling
515516
self._adaptation_window = int(adaptation_window)
517+
self._adaptation_window_multiplier = float(adaptation_window_multiplier)
516518
self._update_window = int(update_window)
517519
self._previous_update = 0
518520

@@ -538,16 +540,16 @@ def update(self, sample, grad, tune):
538540
if (delta + 1) % self._update_window == 0:
539541
self._update_from_weightvar(self._foreground_cov)
540542

541-
# Reset the background covariance if we are at the end of the adaptation window.
543+
# Reset the background covariance if we are at the end of the adaptation
544+
# window.
542545
if delta >= self._adaptation_window:
543546
self._foreground_cov = self._background_cov
544547
self._background_cov = _WeightedCovariance(
545548
self._n, dtype=self.dtype
546549
)
547550

548551
self._previous_update = self._n_samples
549-
if self._doubling:
550-
self._adaptation_window *= 2
552+
self._adaptation_window = int(self._adaptation_window * self._adaptation_window_multiplier)
551553

552554
self._n_samples += 1
553555

pymc3/tests/test_quadpotential.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,15 @@ def test_full_adapt_adaptation_window(seed=8978):
225225
for i in range(window + 1):
226226
pot.update(np.random.randn(2), None, True)
227227
assert pot._previous_update == window
228-
assert pot._adaptation_window == window * 2
228+
assert pot._adaptation_window == window * pot._adaptation_window_multiplier
229229

230230
pot = quadpotential.QuadPotentialFullAdapt(
231-
2, np.zeros(2), np.eye(2), 1, adaptation_window=window, doubling=False
231+
2, np.zeros(2), np.eye(2), 1, adaptation_window=window
232232
)
233233
for i in range(window + 1):
234234
pot.update(np.random.randn(2), None, True)
235235
assert pot._previous_update == window
236-
assert pot._adaptation_window == window
236+
assert pot._adaptation_window == window * pot._adaptation_window_multiplier
237237

238238

239239
def test_full_adapt_not_invertible():

0 commit comments

Comments
 (0)