Skip to content

Commit 22d5ecd

Browse files
authored
change convergence callback (#2098)
* change callback (cherry picked from commit a7ecea7) * add optioanal arguments to callback * add test * change name, add docs * fix test
1 parent bc69427 commit 22d5ecd

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,27 @@ def test_fit(method, kwargs, error):
361361
fit(10, method=method, **kwargs)
362362
else:
363363
fit(10, method=method, **kwargs)
364+
365+
366+
@pytest.mark.parametrize(
367+
'diff',
368+
[
369+
'relative',
370+
'absolute'
371+
]
372+
)
373+
@pytest.mark.parametrize(
374+
'ord',
375+
[1, 2, np.inf]
376+
)
377+
def test_callbacks(diff, ord):
378+
cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord)
379+
380+
class _approx:
381+
params = (theano.shared(np.asarray([1, 2, 3])), )
382+
383+
approx = _approx()
384+
385+
with pytest.raises(StopIteration):
386+
cb(approx, None, 1)
387+
cb(approx, None, 10)

pymc3/variational/callbacks.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,63 @@ def __call__(self, approx, loss, i):
1111
raise NotImplementedError
1212

1313

14+
def relative(current, prev, eps=1e-6):
15+
return (np.abs(current - prev)+eps)/(np.abs(prev)+eps)
16+
17+
18+
def absolute(current, prev):
19+
return np.abs(current - prev)
20+
21+
_diff = dict(
22+
relative=relative,
23+
absolute=absolute
24+
)
25+
26+
1427
class CheckParametersConvergence(Callback):
15-
def __init__(self, every=1000, tolerance=1e-3, eps=1e-10):
28+
"""Convergence stopping check
29+
30+
Parameters
31+
----------
32+
every : int
33+
check frequency
34+
tolerance : float
35+
if diff norm < tolerance : break
36+
diff : str
37+
difference type one of {'absolute', 'relative'}
38+
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
39+
see more info in :func:`numpy.linalg.norm`
40+
41+
Examples
42+
--------
43+
>>> with model:
44+
... approx = pm.fit(
45+
... n=10000, callbacks=[
46+
... CheckParametersConvergence(
47+
... every=50, diff='absolute',
48+
... tolerance=1e-4)
49+
... ]
50+
... )
51+
"""
52+
53+
def __init__(self, every=1000, tolerance=1e-3, diff='relative', ord=np.inf):
54+
self._diff = _diff[diff]
55+
self.ord = ord
1656
self.every = every
1757
self.prev = None
1858
self.tolerance = tolerance
19-
self.eps = np.float32(eps)
2059

2160
def __call__(self, approx, _, i):
2261
if self.prev is None:
2362
self.prev = self.flatten_shared(approx.params)
63+
return
2464
if i % self.every or i < self.every:
2565
return
2666
current = self.flatten_shared(approx.params)
2767
prev = self.prev
28-
eps = self.eps
29-
delta = (np.abs(current - prev)+eps)/(np.abs(prev)+eps)
68+
delta = self._diff(current, prev) # type: np.ndarray
3069
self.prev = current
31-
norm = delta.max()
70+
norm = np.linalg.norm(delta, self.ord)
3271
if norm < self.tolerance:
3372
raise StopIteration('Convergence archived at %d' % i)
3473

0 commit comments

Comments
 (0)