@@ -147,8 +147,8 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
147
147
# Optimization loop
148
148
elbos = np .empty (n )
149
149
divergence_flag = False
150
+ progress = trange (n )
150
151
try :
151
- progress = trange (n )
152
152
uw_i , elbo_current = f ()
153
153
if np .isnan (elbo_current ):
154
154
raise FloatingPointError ('NaN occurred in ADVI optimization.' )
@@ -171,12 +171,12 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
171
171
avg_delta = np .mean (circ_buff )
172
172
med_delta = np .median (circ_buff )
173
173
174
- if avg_delta < tol_obj :
174
+ if i > 0 and avg_delta < tol_obj :
175
175
pm ._log .info ('Mean ELBO converged.' )
176
176
converged = True
177
177
elbos = elbos [:(i + 1 )]
178
178
break
179
- elif med_delta < tol_obj :
179
+ elif i > 0 and med_delta < tol_obj :
180
180
pm ._log .info ('Median ELBO converged.' )
181
181
converged = True
182
182
elbos = elbos [:(i + 1 )]
@@ -186,7 +186,7 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
186
186
divergence_flag = True
187
187
else :
188
188
divergence_flag = False
189
-
189
+
190
190
except KeyboardInterrupt :
191
191
elbos = elbos [:i ]
192
192
if n < 10 :
@@ -202,10 +202,12 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
202
202
else :
203
203
avg_elbo = elbos [- n // 10 :].mean ()
204
204
pm ._log .info ('Finished [100%]: Average ELBO = {:,.5g}' .format (avg_elbo ))
205
-
205
+ finally :
206
+ progress .close ()
207
+
206
208
if divergence_flag :
207
209
pm ._log .info ('Evidence of divergence detected, inspect ELBO.' )
208
-
210
+
209
211
# Estimated parameters
210
212
l = int (uw_i .size / 2 )
211
213
u = bij .rmap (uw_i [:l ])
0 commit comments