@@ -966,8 +966,6 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
966
966
967
967
968
968
def _mp_sample (** kwargs ):
969
- import sys
970
-
971
969
cores = kwargs .pop ('cores' )
972
970
chain = kwargs .pop ('chain' )
973
971
rseed = kwargs .pop ('random_seed' )
@@ -978,15 +976,21 @@ def _mp_sample(**kwargs):
978
976
step = kwargs .pop ('step' )
979
977
progressbar = kwargs .pop ('progressbar' )
980
978
use_mmap = kwargs .pop ('use_mmap' )
979
+ model = kwargs .pop ('model' , None )
980
+ trace = kwargs .pop ('trace' , None )
981
981
982
982
if sys .version_info .major >= 3 :
983
983
import pymc3 .parallel_sampling as ps
984
984
985
- model = modelcontext (kwargs .pop ('model' , None ))
986
- trace = kwargs .pop ('trace' , None )
985
+ # We did draws += tune in pm.sample
986
+ draws -= tune
987
+
987
988
traces = []
988
989
for idx in range (chain , chain + chains ):
989
- strace = _choose_backend (trace , idx , model = model )
990
+ if trace is not None :
991
+ strace = _choose_backend (copy (trace ), idx , model = model )
992
+ else :
993
+ strace = _choose_backend (None , idx , model = model )
990
994
# TODO what is this for?
991
995
update_start_vals (start [idx - chain ], model .test_point , model )
992
996
if step .generates_stats and strace .supports_sampler_stats :
@@ -997,20 +1001,27 @@ def _mp_sample(**kwargs):
997
1001
998
1002
sampler = ps .ParallelSampler (
999
1003
draws , tune , chains , cores , rseed , start , step , chain , progressbar )
1000
- with sampler :
1001
- for draw in sampler :
1002
- trace = traces [draw .chain - chain ]
1003
- if trace .supports_sampler_stats and draw .stats is not None :
1004
- trace .record (draw .point , draw .stats )
1005
- else :
1006
- trace .record (draw .point )
1007
- if draw .is_last :
1008
- trace .close ()
1009
- return MultiTrace (traces )
1004
+ try :
1005
+ with sampler :
1006
+ for draw in sampler :
1007
+ trace = traces [draw .chain - chain ]
1008
+ if trace .supports_sampler_stats and draw .stats is not None :
1009
+ trace .record (draw .point , draw .stats )
1010
+ else :
1011
+ trace .record (draw .point )
1012
+ if draw .is_last :
1013
+ trace .close ()
1014
+ return MultiTrace (traces )
1015
+ except KeyboardInterrupt :
1016
+ traces , length = _choose_chains (traces , tune )
1017
+ return MultiTrace (traces )[:length ]
1018
+ finally :
1019
+ for trace in traces :
1020
+ trace .close ()
1010
1021
1011
1022
else :
1012
1023
chain_nums = list (range (chain , chain + chains ))
1013
- pbars = [kwargs . pop ( ' progressbar' ) ] + [False ] * (chains - 1 )
1024
+ pbars = [progressbar ] + [False ] * (chains - 1 )
1014
1025
jobs = (delayed (_sample )(* args , ** kwargs )
1015
1026
for args in zip (chain_nums , pbars , rseed , start ))
1016
1027
if use_mmap :
@@ -1020,6 +1031,35 @@ def _mp_sample(**kwargs):
1020
1031
return MultiTrace (traces )
1021
1032
1022
1033
1034
+ def _choose_chains (traces , tune ):
1035
+ if tune is None :
1036
+ tune = 0
1037
+
1038
+ if not traces :
1039
+ return []
1040
+
1041
+ lengths = [max (0 , len (trace ) - tune ) for trace in traces ]
1042
+ if not sum (lengths ):
1043
+ raise ValueError ('Not enough samples to build a trace.' )
1044
+
1045
+ idxs = np .argsort (lengths )[::- 1 ]
1046
+ l_sort = np .array (lengths )[idxs ]
1047
+
1048
+ final_length = l_sort [0 ]
1049
+ last_total = 0
1050
+ for i , length in enumerate (l_sort ):
1051
+ total = (i + 1 ) * length
1052
+ if total < last_total :
1053
+ use_until = i
1054
+ break
1055
+ last_total = total
1056
+ final_length = length
1057
+ else :
1058
+ use_until = len (lengths )
1059
+
1060
+ return [traces [idx ] for idx in idxs [:use_until ]], final_length + tune
1061
+
1062
+
1023
1063
def stop_tuning (step ):
1024
1064
""" stop tuning the current step method """
1025
1065
0 commit comments