From b7990833a25a386e8a251f2ccd14559ea687b5f7 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 1 May 2017 02:37:46 +0300 Subject: [PATCH 1/5] change callback (cherry picked from commit a7ecea7) --- pymc3/variational/callbacks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pymc3/variational/callbacks.py b/pymc3/variational/callbacks.py index e3caa6013a..d09186b02c 100644 --- a/pymc3/variational/callbacks.py +++ b/pymc3/variational/callbacks.py @@ -12,11 +12,10 @@ def __call__(self, approx, loss, i): class CheckParametersConvergence(Callback): - def __init__(self, every=1000, tolerance=1e-3, eps=1e-10): + def __init__(self, every=1000, tolerance=1e-3): self.every = every self.prev = None self.tolerance = tolerance - self.eps = np.float32(eps) def __call__(self, approx, _, i): if self.prev is None: @@ -25,8 +24,7 @@ def __call__(self, approx, _, i): return current = self.flatten_shared(approx.params) prev = self.prev - eps = self.eps - delta = (np.abs(current - prev)+eps)/(np.abs(prev)+eps) + delta = np.abs(current - prev) self.prev = current norm = delta.max() if norm < self.tolerance: From beb3f925b50cd65806e6f39f65c6a11379543952 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 12 May 2017 14:53:11 +0300 Subject: [PATCH 2/5] add optioanal arguments to callback --- pymc3/variational/callbacks.py | 36 +++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/pymc3/variational/callbacks.py b/pymc3/variational/callbacks.py index d09186b02c..4f25e0eef1 100644 --- a/pymc3/variational/callbacks.py +++ b/pymc3/variational/callbacks.py @@ -11,8 +11,37 @@ def __call__(self, approx, loss, i): raise NotImplementedError +def percentage(current, prev, eps=1e-6): + return (np.abs(current - prev)+eps)/(np.abs(prev)+eps) + + +def absolute(current, prev): + return np.abs(current - prev) + +_diff = dict( + percentage=percentage, + absolute=absolute +) + + class CheckParametersConvergence(Callback): - def __init__(self, every=1000, tolerance=1e-3): + """ + + Parameters + ---------- + every : int + check frequency + tolerance : float + if diff norm < tolerance : break + diff : str + difference type one of {'absolute', 'percentage'} + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + see more info in np.linalg.norm + """ + + def __init__(self, every=1000, tolerance=1e-3, diff='percentage', ord=np.inf): + self._diff = _diff[diff] + self.ord = ord self.every = every self.prev = None self.tolerance = tolerance @@ -20,13 +49,14 @@ def __init__(self, every=1000, tolerance=1e-3): def __call__(self, approx, _, i): if self.prev is None: self.prev = self.flatten_shared(approx.params) + return if i % self.every or i < self.every: return current = self.flatten_shared(approx.params) prev = self.prev - delta = np.abs(current - prev) + delta = self._diff(current - prev) # type: np.ndarray self.prev = current - norm = delta.max() + norm = np.linalg.norm(delta, self.ord) if norm < self.tolerance: raise StopIteration('Convergence archived at %d' % i) From c6fd7597fca2c0623ffe030a991a79ebda27609a Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 12 May 2017 15:21:15 +0300 Subject: [PATCH 3/5] add test --- pymc3/tests/test_variational_inference.py | 24 +++++++++++++++++++++++ pymc3/variational/callbacks.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 3ce98ffc14..a277f5133b 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -361,3 +361,27 @@ def test_fit(method, kwargs, error): fit(10, method=method, **kwargs) else: fit(10, method=method, **kwargs) + + +@pytest.mark.parametrize( + 'diff', + [ + 'percentage', + 'absolute' + ] +) +@pytest.mark.parametrize( + 'ord', + [1, 2, np.inf] +) +def test_callbacks(diff, ord): + cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord) + + class _approx: + params = (theano.shared(np.asarray([1, 2, 3])), ) + + approx = _approx() + + with pytest.raises(StopIteration): + cb(approx, None, 1) + cb(approx, None, 10) diff --git a/pymc3/variational/callbacks.py b/pymc3/variational/callbacks.py index 4f25e0eef1..dcec7c5e71 100644 --- a/pymc3/variational/callbacks.py +++ b/pymc3/variational/callbacks.py @@ -54,7 +54,7 @@ def __call__(self, approx, _, i): return current = self.flatten_shared(approx.params) prev = self.prev - delta = self._diff(current - prev) # type: np.ndarray + delta = self._diff(current, prev) # type: np.ndarray self.prev = current norm = np.linalg.norm(delta, self.ord) if norm < self.tolerance: From a0e06abc7812ef0f492346d18a728cd7d226f9f4 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sat, 13 May 2017 00:56:34 +0300 Subject: [PATCH 4/5] change name, add docs --- pymc3/variational/callbacks.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pymc3/variational/callbacks.py b/pymc3/variational/callbacks.py index dcec7c5e71..0b2728d6c9 100644 --- a/pymc3/variational/callbacks.py +++ b/pymc3/variational/callbacks.py @@ -11,7 +11,7 @@ def __call__(self, approx, loss, i): raise NotImplementedError -def percentage(current, prev, eps=1e-6): +def relative(current, prev, eps=1e-6): return (np.abs(current - prev)+eps)/(np.abs(prev)+eps) @@ -19,13 +19,13 @@ def absolute(current, prev): return np.abs(current - prev) _diff = dict( - percentage=percentage, + relative=relative, absolute=absolute ) class CheckParametersConvergence(Callback): - """ + """Convergence stopping check Parameters ---------- @@ -34,12 +34,23 @@ class CheckParametersConvergence(Callback): tolerance : float if diff norm < tolerance : break diff : str - difference type one of {'absolute', 'percentage'} + difference type one of {'absolute', 'relative'} ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - see more info in np.linalg.norm + see more info in :func:`numpy.linalg.norm` + + Examples + -------- + >>> with model: + ... approx = pm.fit( + ... n=10000, callbacks=[ + ... CheckParametersConvergence( + ... every=50, diff='absolute', + ... tolerance=1e-4) + ... ] + ... ) """ - def __init__(self, every=1000, tolerance=1e-3, diff='percentage', ord=np.inf): + def __init__(self, every=1000, tolerance=1e-3, diff='relative', ord=np.inf): self._diff = _diff[diff] self.ord = ord self.every = every From 4fdd4b908c7a05f07c52d3ff877b58884b974249 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sat, 13 May 2017 10:17:44 +0300 Subject: [PATCH 5/5] fix test --- pymc3/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index a277f5133b..10e0635f57 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -366,7 +366,7 @@ def test_fit(method, kwargs, error): @pytest.mark.parametrize( 'diff', [ - 'percentage', + 'relative', 'absolute' ] )