Skip to content

Commit ee7c5bc

Browse files
michaelosthegeJunpeng Lao
authored and
Junpeng Lao
committed
PEP8 style, daemonizing child processes (#2747)
1 parent cff9ea9 commit ee7c5bc

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc3/sampling.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,14 +654,16 @@ def __init__(self, steppers, parallelize):
654654
# configure a child process for each stepper
655655
pm._log.info('Attempting to parallelize chains.')
656656
import multiprocessing
657-
for c,stepper in enumerate(steppers):
657+
for c, stepper in enumerate(tqdm(steppers)):
658658
slave_end, master_end = multiprocessing.Pipe()
659659
stepper_dumps = pickle.dumps(stepper, protocol=4)
660660
process = multiprocessing.Process(
661661
target=self.__class__._run_slave,
662662
args=(c, stepper_dumps, slave_end),
663663
name='ChainWalker{}'.format(c)
664664
)
665+
# we want the child process to exit if the parent is terminated
666+
process.daemon = True
665667
# Starting the process might fail and takes time.
666668
# By doing it in the constructor, the sampling progress bar
667669
# will not be confused by the process start.
@@ -794,7 +796,7 @@ def _prepare_iter_population(draws, chains, step, start, parallelize, tune=None,
794796

795797
# 1. prepare a BaseTrace for each chain
796798
traces = [_choose_backend(None, chain, model=model) for chain in chains]
797-
for c,strace in enumerate(traces):
799+
for c, strace in enumerate(traces):
798800
# initialize the trace size and variable transforms
799801
if len(strace) > 0:
800802
update_start_vals(start[c], strace.point(-1), model)
@@ -860,7 +862,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
860862
updates = popstep.step(i == tune, points)
861863

862864
# apply the update to the points and record to the traces
863-
for c,strace in enumerate(traces):
865+
for c, strace in enumerate(traces):
864866
if steppers[c].generates_stats:
865867
points[c], states = updates[c]
866868
if strace.supports_sampler_stats:
@@ -873,17 +875,17 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
873875
# yield the state of all chains in parallel
874876
yield traces
875877
except KeyboardInterrupt:
876-
for c,strace in enumerate(traces):
878+
for c, strace in enumerate(traces):
877879
strace.close()
878880
if hasattr(steppers[c], 'report'):
879881
steppers[c].report._finalize(strace)
880882
raise
881883
except BaseException:
882-
for c,strace in enumerate(traces):
884+
for c, strace in enumerate(traces):
883885
strace.close()
884886
raise
885887
else:
886-
for c,strace in enumerate(traces):
888+
for c, strace in enumerate(traces):
887889
strace.close()
888890
if hasattr(steppers[c], 'report'):
889891
steppers[c].report._finalize(strace)

0 commit comments

Comments
 (0)