Skip to content

Commit e0f1118

Browse files
committed
Initialize QuadPotentialDiagApaptExp state variables in __init__
1 parent eac88c2 commit e0f1118

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pymc/step_methods/hmc/quadpotential.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,8 @@ def current_mean(self, out=None):
481481
class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
482482
_alpha: float
483483
_stop_adaptation: float
484-
_variance_estimator: ExpWeightedVarianceState
485-
486-
_variance_estimator_grad: ExpWeightedVarianceState | None = None
484+
_variance_estimator: ExpWeightedVarianceState | None
485+
_variance_estimator_grad: ExpWeightedVarianceState | None
487486

488487

489488
class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
@@ -529,6 +528,8 @@ def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None
529528
if stop_adaptation is None:
530529
stop_adaptation = np.inf
531530
self._stop_adaptation = stop_adaptation
531+
self._variance_estimator = None
532+
self._variance_estimator_grad = None
532533

533534
def update(self, sample, grad, tune):
534535
if tune and self._n_samples < self._stop_adaptation:

0 commit comments

Comments
 (0)