Skip to content

Commit dd21cc4

Browse files
committed
Return partial traces if sampling is interrupted
1 parent f122872 commit dd21cc4

File tree

3 files changed

+70
-22
lines changed

3 files changed

+70
-22
lines changed

pymc3/backends/text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,9 @@ def record(self, point):
9999
self._fh.write(','.join(columns) + '\n')
100100

101101
def close(self):
102-
self._fh.close()
103-
self._fh = None # Avoid serialization issue.
102+
if self._fh is not None:
103+
self._fh.close()
104+
self._fh = None # Avoid serialization issue.
104105

105106
# Selection methods
106107

pymc3/parallel_sampling.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class _Process(multiprocessing.Process):
2828
We communicate with the main process using a pipe,
2929
and send finished samples using shared memory.
3030
"""
31-
def __init__(self, msg_pipe, step_method, shared_point, draws, tune, seed):
31+
def __init__(self, name, msg_pipe, step_method, shared_point,
32+
draws, tune, seed):
3233
super(_Process, self).__init__(daemon=True)
3334
self._msg_pipe = msg_pipe
3435
self._step_method = step_method
@@ -116,6 +117,7 @@ class ProcessAdapter(object):
116117
"""Control a Chain process from the main thread."""
117118
def __init__(self, draws, tune, step_method, chain, seed, start):
118119
self.chain = chain
120+
process_name = "worker_chain_%s" % chain
119121
self._msg_pipe, remote_conn = multiprocessing.Pipe()
120122

121123
self._shared_point = {}
@@ -138,7 +140,8 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
138140
self._num_samples = 0
139141

140142
self._process = _Process(
141-
remote_conn, step_method, self._shared_point, draws, tune, seed)
143+
process_name, remote_conn, step_method, self._shared_point,
144+
draws, tune, seed)
142145
# We fork right away, so that the main process can start tqdm threads
143146
self._process.start()
144147

@@ -185,7 +188,7 @@ def recv_draw(processes, timeout=3600):
185188
elif msg[0] == 'writing_done':
186189
proc._readable = True
187190
proc._num_samples += 1
188-
return (proc, *msg[1:])
191+
return (proc,) + msg[1:]
189192
else:
190193
raise ValueError('Sampler sent bad message.')
191194

@@ -200,7 +203,7 @@ def terminate_all(processes, patience=2):
200203
start_time = time.time()
201204
try:
202205
for process in processes:
203-
timeout = start_time + patience - time.time()
206+
timeout = time.time() + patience - start_time
204207
if timeout < 0:
205208
raise multiprocessing.TimeoutError()
206209
process.join(timeout)
@@ -285,6 +288,10 @@ def __iter__(self):
285288
if self._progress is not None:
286289
self._progress[proc.chain - self._start_chain_num].close()
287290

291+
# We could also yield proc.shared_point_view directly,
292+
# and only call proc.write_next() after the yield returns.
293+
# This seems to be faster overall though, as the worker
294+
# loses less time waiting.
288295
point = {name: val.copy()
289296
for name, val in proc.shared_point_view.items()}
290297

pymc3/sampling.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -966,8 +966,6 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
966966

967967

968968
def _mp_sample(**kwargs):
969-
import sys
970-
971969
cores = kwargs.pop('cores')
972970
chain = kwargs.pop('chain')
973971
rseed = kwargs.pop('random_seed')
@@ -978,15 +976,21 @@ def _mp_sample(**kwargs):
978976
step = kwargs.pop('step')
979977
progressbar = kwargs.pop('progressbar')
980978
use_mmap = kwargs.pop('use_mmap')
979+
model = kwargs.pop('model', None)
980+
trace = kwargs.pop('trace', None)
981981

982982
if sys.version_info.major >= 3:
983983
import pymc3.parallel_sampling as ps
984984

985-
model = modelcontext(kwargs.pop('model', None))
986-
trace = kwargs.pop('trace', None)
985+
# We did draws += tune in pm.sample
986+
draws -= tune
987+
987988
traces = []
988989
for idx in range(chain, chain + chains):
989-
strace = _choose_backend(trace, idx, model=model)
990+
if trace is not None:
991+
strace = _choose_backend(copy(trace), idx, model=model)
992+
else:
993+
strace = _choose_backend(None, idx, model=model)
990994
# TODO what is this for?
991995
update_start_vals(start[idx - chain], model.test_point, model)
992996
if step.generates_stats and strace.supports_sampler_stats:
@@ -997,20 +1001,27 @@ def _mp_sample(**kwargs):
9971001

9981002
sampler = ps.ParallelSampler(
9991003
draws, tune, chains, cores, rseed, start, step, chain, progressbar)
1000-
with sampler:
1001-
for draw in sampler:
1002-
trace = traces[draw.chain - chain]
1003-
if trace.supports_sampler_stats and draw.stats is not None:
1004-
trace.record(draw.point, draw.stats)
1005-
else:
1006-
trace.record(draw.point)
1007-
if draw.is_last:
1008-
trace.close()
1009-
return MultiTrace(traces)
1004+
try:
1005+
with sampler:
1006+
for draw in sampler:
1007+
trace = traces[draw.chain - chain]
1008+
if trace.supports_sampler_stats and draw.stats is not None:
1009+
trace.record(draw.point, draw.stats)
1010+
else:
1011+
trace.record(draw.point)
1012+
if draw.is_last:
1013+
trace.close()
1014+
return MultiTrace(traces)
1015+
except KeyboardInterrupt:
1016+
traces, length = _choose_chains(traces, tune)
1017+
return MultiTrace(traces)[:length]
1018+
finally:
1019+
for trace in traces:
1020+
trace.close()
10101021

10111022
else:
10121023
chain_nums = list(range(chain, chain + chains))
1013-
pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
1024+
pbars = [progressbar] + [False] * (chains - 1)
10141025
jobs = (delayed(_sample)(*args, **kwargs)
10151026
for args in zip(chain_nums, pbars, rseed, start))
10161027
if use_mmap:
@@ -1020,6 +1031,35 @@ def _mp_sample(**kwargs):
10201031
return MultiTrace(traces)
10211032

10221033

1034+
def _choose_chains(traces, tune):
1035+
if tune is None:
1036+
tune = 0
1037+
1038+
if not traces:
1039+
return []
1040+
1041+
lengths = [max(0, len(trace) - tune) for trace in traces]
1042+
if not sum(lengths):
1043+
raise ValueError('Not enough samples to build a trace.')
1044+
1045+
idxs = np.argsort(lengths)[::-1]
1046+
l_sort = np.array(lengths)[idxs]
1047+
1048+
final_length = l_sort[0]
1049+
last_total = 0
1050+
for i, length in enumerate(l_sort):
1051+
total = (i + 1) * length
1052+
if total < last_total:
1053+
use_until = i
1054+
break
1055+
last_total = total
1056+
final_length = length
1057+
else:
1058+
use_until = len(lengths)
1059+
1060+
return [traces[idx] for idx in idxs[:use_until]], final_length + tune
1061+
1062+
10231063
def stop_tuning(step):
10241064
""" stop tuning the current step method """
10251065

0 commit comments

Comments
 (0)