@@ -234,12 +234,8 @@ def terminate_all(processes, patience=2):
234
234
235
235
class ParallelSampler (object ):
236
236
def __init__ (self , draws , tune , chains , cores , seeds , start_points ,
237
- step_method , start_chain_num = 0 , progressbar = True ,
238
- notebook = True ):
239
- if progressbar and notebook :
240
- import tqdm
241
- tqdm_ = tqdm .tqdm_notebook
242
- elif progressbar :
237
+ step_method , start_chain_num = 0 , progressbar = True ):
238
+ if progressbar :
243
239
import tqdm
244
240
tqdm_ = tqdm .tqdm
245
241
@@ -257,18 +253,11 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
257
253
self ._in_context = False
258
254
self ._start_chain_num = start_chain_num
259
255
260
- self ._global_progress = self . _progress = None
256
+ self ._progress = None
261
257
if progressbar :
262
- self ._global_progress = tqdm_ (
263
- total = chains , unit = 'chains' , position = 0 )
264
- self ._progress = [
265
- tqdm_ (
266
- desc = ' Chain %i' % (chain + start_chain_num ),
267
- unit = 'draws' ,
268
- position = chain + 1 ,
269
- total = draws + tune )
270
- for chain in range (chains )
271
- ]
258
+ self ._progress = tqdm_ (
259
+ total = chains * (draws + tune ), unit = 'draws' ,
260
+ desc = 'Sampling %s chains' % chains )
272
261
273
262
def _make_active (self ):
274
263
while self ._inactive and len (self ._active ) < self ._max_active :
@@ -286,17 +275,13 @@ def __iter__(self):
286
275
draw = ProcessAdapter .recv_draw (self ._active )
287
276
proc , is_last , draw , tuning , stats , warns = draw
288
277
if self ._progress is not None :
289
- self ._progress [ proc . chain - self . _start_chain_num ] .update ()
278
+ self ._progress .update ()
290
279
291
280
if is_last :
292
281
proc .join ()
293
282
self ._active .remove (proc )
294
283
self ._finished .append (proc )
295
284
self ._make_active ()
296
- if self ._global_progress is not None :
297
- self ._global_progress .update ()
298
- if self ._progress is not None :
299
- self ._progress [proc .chain - self ._start_chain_num ].close ()
300
285
301
286
# We could also yield proc.shared_point_view directly,
302
287
# and only call proc.write_next() after the yield returns.
@@ -318,6 +303,4 @@ def __enter__(self):
318
303
def __exit__ (self , * args ):
319
304
ProcessAdapter .terminate_all (self ._samplers )
320
305
if self ._progress is not None :
321
- self ._global_progress .close ()
322
- for progress in self ._progress :
323
- progress .close ()
306
+ self ._progress .close ()
0 commit comments