Skip to content

Replaced tqdm progressbar with fastprogress #3667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 37 additions & 38 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,31 @@

def _get_broken_pipe_exception():
import sys
if sys.platform == 'win32':
return RuntimeError("The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running.")

if sys.platform == "win32":
return RuntimeError(
"The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running."
)
else:
return None

Expand Down Expand Up @@ -237,7 +240,6 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
tune,
seed,
)
# We fork right away, so that the main process can start tqdm threads
try:
self._process.start()
except IOError as e:
Expand Down Expand Up @@ -346,8 +348,7 @@ def __init__(
start_chain_num=0,
progressbar=True,
):
if progressbar:
from tqdm import tqdm
from fastprogress import progress_bar

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)
Expand All @@ -369,14 +370,13 @@ def __init__(

self._progress = None
self._divergences = 0
self._total_draws = 0
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
self._chains = chains
if progressbar:
self._progress = tqdm(
total=chains * (draws + tune),
unit="draws",
desc=self._desc.format(self)
)
self._progress = progress_bar(
range(chains * (draws + tune)), display=progressbar, auto_update=False
)
self._progress.comment = self._desc.format(self)

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
Expand All @@ -393,11 +393,11 @@ def __iter__(self):
while self._active:
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats, warns = draw
if self._progress is not None:
if not tuning and stats and stats[0].get('diverging'):
self._divergences += 1
self._progress.set_description(self._desc.format(self))
self._progress.update()
self._total_draws += 1
if not tuning and stats and stats[0].get("diverging"):
self._divergences += 1
self._progress.comment = self._desc.format(self)
self._progress.update(self._total_draws)

if is_last:
proc.join()
Expand All @@ -423,8 +423,7 @@ def __enter__(self):

def __exit__(self, *args):
ProcessAdapter.terminate_all(self._samplers)
if self._progress is not None:
self._progress.close()


def _cpu_count():
"""Try to guess the number of CPUs in the system.
Expand Down
Loading