@@ -126,8 +126,16 @@ def isquadpotential(value):
126
126
class QuadPotentialDiagAdapt (QuadPotential ):
127
127
"""Adapt a diagonal mass matrix from the sample variances."""
128
128
129
- def __init__ (self , n , initial_mean , initial_diag = None , initial_weight = 0 ,
130
- adaptation_window = 101 , dtype = None ):
129
+ def __init__ (
130
+ self ,
131
+ n ,
132
+ initial_mean ,
133
+ initial_diag = None ,
134
+ initial_weight = 0 ,
135
+ adaptation_window = 101 ,
136
+ adaptation_window_multiplier = 1 ,
137
+ dtype = None ,
138
+ ):
131
139
"""Set up a diagonal mass matrix."""
132
140
if initial_diag is not None and initial_diag .ndim != 1 :
133
141
raise ValueError ('Initial diagonal must be one-dimensional.' )
@@ -158,6 +166,7 @@ def __init__(self, n, initial_mean, initial_diag=None, initial_weight=0,
158
166
self ._background_var = _WeightedVariance (self ._n , dtype = self .dtype )
159
167
self ._n_samples = 0
160
168
self .adaptation_window = adaptation_window
169
+ self .adaptation_window_multiplier = float (adaptation_window_multiplier )
161
170
162
171
def velocity (self , x , out = None ):
163
172
"""Compute the current velocity at a position in parameter space."""
@@ -190,15 +199,14 @@ def update(self, sample, grad, tune):
190
199
if not tune :
191
200
return
192
201
193
- window = self .adaptation_window
194
-
195
202
self ._foreground_var .add_sample (sample , weight = 1 )
196
203
self ._background_var .add_sample (sample , weight = 1 )
197
204
self ._update_from_weightvar (self ._foreground_var )
198
205
199
- if self ._n_samples > 0 and self ._n_samples % window == 0 :
206
+ if self ._n_samples > 0 and self ._n_samples % self . adaptation_window == 0 :
200
207
self ._foreground_var = self ._background_var
201
208
self ._background_var = _WeightedVariance (self ._n , dtype = self .dtype )
209
+ self .adaptation_window = int (self .adaptation_window * self .adaptation_window_multiplier )
202
210
203
211
self ._n_samples += 1
204
212
@@ -458,22 +466,16 @@ def velocity_energy(self, x, v_out):
458
466
459
467
460
468
class QuadPotentialFullAdapt (QuadPotentialFull ):
461
- """Adapt a dense mass matrix using the sample covariances
462
-
463
- If the parameter ``doubling`` is true, the adaptation window is doubled
464
- every time it is passed. This can lead to better convergence of the mass
465
- matrix estimation.
466
-
467
- """
469
+ """Adapt a dense mass matrix using the sample covariances."""
468
470
def __init__ (
469
471
self ,
470
472
n ,
471
473
initial_mean ,
472
474
initial_cov = None ,
473
475
initial_weight = 0 ,
474
476
adaptation_window = 101 ,
477
+ adaptation_window_multiplier = 2 ,
475
478
update_window = 1 ,
476
- doubling = True ,
477
479
dtype = None ,
478
480
):
479
481
warnings .warn ("QuadPotentialFullAdapt is an experimental feature" )
@@ -511,8 +513,8 @@ def __init__(
511
513
self ._background_cov = _WeightedCovariance (self ._n , dtype = self .dtype )
512
514
self ._n_samples = 0
513
515
514
- self ._doubling = doubling
515
516
self ._adaptation_window = int (adaptation_window )
517
+ self ._adaptation_window_multiplier = float (adaptation_window_multiplier )
516
518
self ._update_window = int (update_window )
517
519
self ._previous_update = 0
518
520
@@ -538,16 +540,16 @@ def update(self, sample, grad, tune):
538
540
if (delta + 1 ) % self ._update_window == 0 :
539
541
self ._update_from_weightvar (self ._foreground_cov )
540
542
541
- # Reset the background covariance if we are at the end of the adaptation window.
543
+ # Reset the background covariance if we are at the end of the adaptation
544
+ # window.
542
545
if delta >= self ._adaptation_window :
543
546
self ._foreground_cov = self ._background_cov
544
547
self ._background_cov = _WeightedCovariance (
545
548
self ._n , dtype = self .dtype
546
549
)
547
550
548
551
self ._previous_update = self ._n_samples
549
- if self ._doubling :
550
- self ._adaptation_window *= 2
552
+ self ._adaptation_window = int (self ._adaptation_window * self ._adaptation_window_multiplier )
551
553
552
554
self ._n_samples += 1
553
555
0 commit comments