Skip to content

Commit ebb3b3e

Browse files
committed
Show one progress bar for all chains
1 parent 212ff07 commit ebb3b3e

File tree

1 file changed

+8
-25
lines changed

1 file changed

+8
-25
lines changed

pymc3/parallel_sampling.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,8 @@ def terminate_all(processes, patience=2):
234234

235235
class ParallelSampler(object):
236236
def __init__(self, draws, tune, chains, cores, seeds, start_points,
237-
step_method, start_chain_num=0, progressbar=True,
238-
notebook=True):
239-
if progressbar and notebook:
240-
import tqdm
241-
tqdm_ = tqdm.tqdm_notebook
242-
elif progressbar:
237+
step_method, start_chain_num=0, progressbar=True):
238+
if progressbar:
243239
import tqdm
244240
tqdm_ = tqdm.tqdm
245241

@@ -257,18 +253,11 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
257253
self._in_context = False
258254
self._start_chain_num = start_chain_num
259255

260-
self._global_progress = self._progress = None
256+
self._progress = None
261257
if progressbar:
262-
self._global_progress = tqdm_(
263-
total=chains, unit='chains', position=0)
264-
self._progress = [
265-
tqdm_(
266-
desc=' Chain %i' % (chain + start_chain_num),
267-
unit='draws',
268-
position=chain + 1,
269-
total=draws + tune)
270-
for chain in range(chains)
271-
]
258+
self._progress = tqdm_(
259+
total=chains * (draws + tune), unit='draws',
260+
desc='Sampling %s chains' % chains)
272261

273262
def _make_active(self):
274263
while self._inactive and len(self._active) < self._max_active:
@@ -286,17 +275,13 @@ def __iter__(self):
286275
draw = ProcessAdapter.recv_draw(self._active)
287276
proc, is_last, draw, tuning, stats, warns = draw
288277
if self._progress is not None:
289-
self._progress[proc.chain - self._start_chain_num].update()
278+
self._progress.update()
290279

291280
if is_last:
292281
proc.join()
293282
self._active.remove(proc)
294283
self._finished.append(proc)
295284
self._make_active()
296-
if self._global_progress is not None:
297-
self._global_progress.update()
298-
if self._progress is not None:
299-
self._progress[proc.chain - self._start_chain_num].close()
300285

301286
# We could also yield proc.shared_point_view directly,
302287
# and only call proc.write_next() after the yield returns.
@@ -318,6 +303,4 @@ def __enter__(self):
318303
def __exit__(self, *args):
319304
ProcessAdapter.terminate_all(self._samplers)
320305
if self._progress is not None:
321-
self._global_progress.close()
322-
for progress in self._progress:
323-
progress.close()
306+
self._progress.close()

0 commit comments

Comments
 (0)