diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 3ce98ffc14..10e0635f57 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -361,3 +361,27 @@ def test_fit(method, kwargs, error): fit(10, method=method, **kwargs) else: fit(10, method=method, **kwargs) + + +@pytest.mark.parametrize( + 'diff', + [ + 'relative', + 'absolute' + ] +) +@pytest.mark.parametrize( + 'ord', + [1, 2, np.inf] +) +def test_callbacks(diff, ord): + cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord) + + class _approx: + params = (theano.shared(np.asarray([1, 2, 3])), ) + + approx = _approx() + + with pytest.raises(StopIteration): + cb(approx, None, 1) + cb(approx, None, 10) diff --git a/pymc3/variational/callbacks.py b/pymc3/variational/callbacks.py index e3caa6013a..0b2728d6c9 100644 --- a/pymc3/variational/callbacks.py +++ b/pymc3/variational/callbacks.py @@ -11,24 +11,63 @@ def __call__(self, approx, loss, i): raise NotImplementedError +def relative(current, prev, eps=1e-6): + return (np.abs(current - prev)+eps)/(np.abs(prev)+eps) + + +def absolute(current, prev): + return np.abs(current - prev) + +_diff = dict( + relative=relative, + absolute=absolute +) + + class CheckParametersConvergence(Callback): - def __init__(self, every=1000, tolerance=1e-3, eps=1e-10): + """Convergence stopping check + + Parameters + ---------- + every : int + check frequency + tolerance : float + if diff norm < tolerance : break + diff : str + difference type one of {'absolute', 'relative'} + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + see more info in :func:`numpy.linalg.norm` + + Examples + -------- + >>> with model: + ... approx = pm.fit( + ... n=10000, callbacks=[ + ... CheckParametersConvergence( + ... every=50, diff='absolute', + ... tolerance=1e-4) + ... ] + ... ) + """ + + def __init__(self, every=1000, tolerance=1e-3, diff='relative', ord=np.inf): + self._diff = _diff[diff] + self.ord = ord self.every = every self.prev = None self.tolerance = tolerance - self.eps = np.float32(eps) def __call__(self, approx, _, i): if self.prev is None: self.prev = self.flatten_shared(approx.params) + return if i % self.every or i < self.every: return current = self.flatten_shared(approx.params) prev = self.prev - eps = self.eps - delta = (np.abs(current - prev)+eps)/(np.abs(prev)+eps) + delta = self._diff(current, prev) # type: np.ndarray self.prev = current - norm = delta.max() + norm = np.linalg.norm(delta, self.ord) if norm < self.tolerance: raise StopIteration('Convergence archived at %d' % i)