@@ -171,16 +171,24 @@ def __init__(
171
171
172
172
self .dtype = dtype
173
173
self ._n = n
174
- self ._var = np .array (initial_diag , dtype = self .dtype , copy = True )
174
+
175
+ self ._initial_mean = initial_mean
176
+ self ._initial_diag = initial_diag
177
+ self ._initial_weight = initial_weight
178
+ self .adaptation_window = adaptation_window
179
+ self .adaptation_window_multiplier = float (adaptation_window_multiplier )
180
+
181
+ self .reset ()
182
+
183
+ def reset (self ):
184
+ self ._var = np .array (self ._initial_diag , dtype = self .dtype , copy = True )
175
185
self ._var_theano = theano .shared (self ._var )
176
- self ._stds = np .sqrt (initial_diag )
186
+ self ._stds = np .sqrt (self . _initial_diag )
177
187
self ._inv_stds = floatX (1. ) / self ._stds
178
188
self ._foreground_var = _WeightedVariance (
179
- self ._n , initial_mean , initial_diag , initial_weight , self .dtype )
189
+ self ._n , self . _initial_mean , self . _initial_diag , self . _initial_weight , self .dtype )
180
190
self ._background_var = _WeightedVariance (self ._n , dtype = self .dtype )
181
191
self ._n_samples = 0
182
- self .adaptation_window = adaptation_window
183
- self .adaptation_window_multiplier = float (adaptation_window_multiplier )
184
192
185
193
def velocity (self , x , out = None ):
186
194
"""Compute the current velocity at a position in parameter space."""
@@ -275,8 +283,8 @@ class QuadPotentialDiagAdaptGrad(QuadPotentialDiagAdapt):
275
283
This is experimental, and may be removed without prior deprication.
276
284
"""
277
285
278
- def __init__ (self , * args , ** kwargs ):
279
- super ().__init__ ( * args , ** kwargs )
286
+ def reset (self ):
287
+ super ().reset ( )
280
288
self ._grads1 = np .zeros (self ._n , dtype = self .dtype )
281
289
self ._ngrads1 = 0
282
290
self ._grads2 = np .zeros (self ._n , dtype = self .dtype )
@@ -518,20 +526,27 @@ def __init__(
518
526
519
527
self .dtype = dtype
520
528
self ._n = n
521
- self ._cov = np .array (initial_cov , dtype = self .dtype , copy = True )
529
+ self ._initial_mean = initial_mean
530
+ self ._initial_cov = initial_cov
531
+ self ._initial_weight = initial_weight
532
+
533
+ self .adaptation_window = int (adaptation_window )
534
+ self .adaptation_window_multiplier = float (adaptation_window_multiplier )
535
+ self ._update_window = int (update_window )
536
+
537
+ self .reset ()
538
+
539
+ def reset (self ):
540
+ self ._previous_update = 0
541
+ self ._cov = np .array (self ._initial_cov , dtype = self .dtype , copy = True )
522
542
self ._chol = scipy .linalg .cholesky (self ._cov , lower = True )
523
543
self ._chol_error = None
524
544
self ._foreground_cov = _WeightedCovariance (
525
- self ._n , initial_mean , initial_cov , initial_weight , self .dtype
545
+ self ._n , self . _initial_mean , self . _initial_cov , self . _initial_weight , self .dtype
526
546
)
527
547
self ._background_cov = _WeightedCovariance (self ._n , dtype = self .dtype )
528
548
self ._n_samples = 0
529
549
530
- self .adaptation_window = int (adaptation_window )
531
- self .adaptation_window_multiplier = float (adaptation_window_multiplier )
532
- self ._update_window = int (update_window )
533
- self ._previous_update = 0
534
-
535
550
def _update_from_weightvar (self , weightvar ):
536
551
weightvar .current_covariance (out = self ._cov )
537
552
try :
0 commit comments