@@ -654,14 +654,16 @@ def __init__(self, steppers, parallelize):
654
654
# configure a child process for each stepper
655
655
pm ._log .info ('Attempting to parallelize chains.' )
656
656
import multiprocessing
657
- for c ,stepper in enumerate (steppers ):
657
+ for c , stepper in enumerate (tqdm ( steppers ) ):
658
658
slave_end , master_end = multiprocessing .Pipe ()
659
659
stepper_dumps = pickle .dumps (stepper , protocol = 4 )
660
660
process = multiprocessing .Process (
661
661
target = self .__class__ ._run_slave ,
662
662
args = (c , stepper_dumps , slave_end ),
663
663
name = 'ChainWalker{}' .format (c )
664
664
)
665
+ # we want the child process to exit if the parent is terminated
666
+ process .daemon = True
665
667
# Starting the process might fail and takes time.
666
668
# By doing it in the constructor, the sampling progress bar
667
669
# will not be confused by the process start.
@@ -794,7 +796,7 @@ def _prepare_iter_population(draws, chains, step, start, parallelize, tune=None,
794
796
795
797
# 1. prepare a BaseTrace for each chain
796
798
traces = [_choose_backend (None , chain , model = model ) for chain in chains ]
797
- for c ,strace in enumerate (traces ):
799
+ for c , strace in enumerate (traces ):
798
800
# initialize the trace size and variable transforms
799
801
if len (strace ) > 0 :
800
802
update_start_vals (start [c ], strace .point (- 1 ), model )
@@ -860,7 +862,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
860
862
updates = popstep .step (i == tune , points )
861
863
862
864
# apply the update to the points and record to the traces
863
- for c ,strace in enumerate (traces ):
865
+ for c , strace in enumerate (traces ):
864
866
if steppers [c ].generates_stats :
865
867
points [c ], states = updates [c ]
866
868
if strace .supports_sampler_stats :
@@ -873,17 +875,17 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
873
875
# yield the state of all chains in parallel
874
876
yield traces
875
877
except KeyboardInterrupt :
876
- for c ,strace in enumerate (traces ):
878
+ for c , strace in enumerate (traces ):
877
879
strace .close ()
878
880
if hasattr (steppers [c ], 'report' ):
879
881
steppers [c ].report ._finalize (strace )
880
882
raise
881
883
except BaseException :
882
- for c ,strace in enumerate (traces ):
884
+ for c , strace in enumerate (traces ):
883
885
strace .close ()
884
886
raise
885
887
else :
886
- for c ,strace in enumerate (traces ):
888
+ for c , strace in enumerate (traces ):
887
889
strace .close ()
888
890
if hasattr (steppers [c ], 'report' ):
889
891
steppers [c ].report ._finalize (strace )
0 commit comments