@@ -11,24 +11,63 @@ def __call__(self, approx, loss, i):
11
11
raise NotImplementedError
12
12
13
13
14
+ def relative (current , prev , eps = 1e-6 ):
15
+ return (np .abs (current - prev )+ eps )/ (np .abs (prev )+ eps )
16
+
17
+
18
+ def absolute (current , prev ):
19
+ return np .abs (current - prev )
20
+
21
+ _diff = dict (
22
+ relative = relative ,
23
+ absolute = absolute
24
+ )
25
+
26
+
14
27
class CheckParametersConvergence (Callback ):
15
- def __init__ (self , every = 1000 , tolerance = 1e-3 , eps = 1e-10 ):
28
+ """Convergence stopping check
29
+
30
+ Parameters
31
+ ----------
32
+ every : int
33
+ check frequency
34
+ tolerance : float
35
+ if diff norm < tolerance : break
36
+ diff : str
37
+ difference type one of {'absolute', 'relative'}
38
+ ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
39
+ see more info in :func:`numpy.linalg.norm`
40
+
41
+ Examples
42
+ --------
43
+ >>> with model:
44
+ ... approx = pm.fit(
45
+ ... n=10000, callbacks=[
46
+ ... CheckParametersConvergence(
47
+ ... every=50, diff='absolute',
48
+ ... tolerance=1e-4)
49
+ ... ]
50
+ ... )
51
+ """
52
+
53
+ def __init__ (self , every = 1000 , tolerance = 1e-3 , diff = 'relative' , ord = np .inf ):
54
+ self ._diff = _diff [diff ]
55
+ self .ord = ord
16
56
self .every = every
17
57
self .prev = None
18
58
self .tolerance = tolerance
19
- self .eps = np .float32 (eps )
20
59
21
60
def __call__ (self , approx , _ , i ):
22
61
if self .prev is None :
23
62
self .prev = self .flatten_shared (approx .params )
63
+ return
24
64
if i % self .every or i < self .every :
25
65
return
26
66
current = self .flatten_shared (approx .params )
27
67
prev = self .prev
28
- eps = self .eps
29
- delta = (np .abs (current - prev )+ eps )/ (np .abs (prev )+ eps )
68
+ delta = self ._diff (current , prev ) # type: np.ndarray
30
69
self .prev = current
31
- norm = delta . max ( )
70
+ norm = np . linalg . norm ( delta , self . ord )
32
71
if norm < self .tolerance :
33
72
raise StopIteration ('Convergence archived at %d' % i )
34
73
0 commit comments