Skip to content

Commit 08603a4

Browse files
committed
Use tqdm_notebook
1 parent 85658d7 commit 08603a4

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

pymc3/parallel_sampling.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from collections import namedtuple
88

99
import six
10-
import tqdm
1110
import numpy as np
1211

1312
from . import theanof
@@ -222,7 +221,15 @@ def terminate_all(processes, patience=2):
222221

223222
class ParallelSampler(object):
224223
def __init__(self, draws, tune, chains, cores, seeds, start_points,
225-
step_method, start_chain_num=0, progressbar=True):
224+
step_method, start_chain_num=0, progressbar=True,
225+
notebook=True):
226+
if progressbar and notebook:
227+
import tqdm
228+
tqdm_ = tqdm.tqdm_notebook
229+
elif progressbar:
230+
import tqdm
231+
tqdm_ = tqdm.tqdm
232+
226233
self._samplers = [
227234
ProcessAdapter(draws, tune, step_method,
228235
chain + start_chain_num, seed, start)
@@ -239,10 +246,10 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
239246

240247
self._global_progress = self._progress = None
241248
if progressbar:
242-
self._global_progress = tqdm.tqdm(
249+
self._global_progress = tqdm_(
243250
total=chains, unit='chains', position=1)
244251
self._progress = [
245-
tqdm.tqdm(
252+
tqdm_(
246253
desc=' Chain %i' % (chain + start_chain_num),
247254
unit='draws',
248255
position=chain + 2,

0 commit comments

Comments
 (0)