Skip to content

Commit 9898e00

Browse files
authored
Merge pull request #3011 from aseyboldt/multiproc
Rewrite parallel sampling using multiprocessing
2 parents 9e7495b + ae1025b commit 9898e00

File tree

10 files changed

+539
-35
lines changed

10 files changed

+539
-35
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
- Improve error message `NaN occurred in optimization.` during ADVI
1414
- Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
1515
- Add `Kumaraswamy` distribution
16+
- Rewrite parallel sampling of multiple chains on py3. This resolves
17+
long-standing issues when transferring large traces to the main process,
18+
avoids pickling issues on UNIX, and allows us to show a progress bar
19+
for all chains. If parallel sampling is interrupted, we now return
20+
partial results.
1621

1722
### Fixes
1823

pymc3/backends/text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,9 @@ def record(self, point):
9999
self._fh.write(','.join(columns) + '\n')
100100

101101
def close(self):
    """Close the backing file handle, if it is still open.

    Safe to call more than once: the handle is set to ``None`` after
    closing so repeated calls are no-ops and the trace object stays
    serializable (open file handles cannot be pickled).
    """
    if self._fh is not None:
        self._fh.close()
        self._fh = None  # Avoid serialization issue.
104105

105106
# Selection methods
106107

pymc3/parallel_sampling.py

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
import multiprocessing
2+
import multiprocessing.sharedctypes
3+
import ctypes
4+
import time
5+
import logging
6+
from collections import namedtuple
7+
import traceback
8+
9+
import six
10+
import numpy as np
11+
12+
from . import theanof
13+
14+
logger = logging.getLogger('pymc3')
15+
16+
17+
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
    """Carries the formatted traceback text of a worker-side exception."""

    def __init__(self, tb):
        self.tb = tb

    def __str__(self):
        return self.tb


class ExceptionWithTraceback:
    """Pickleable wrapper that keeps a worker exception's traceback text."""

    def __init__(self, exc, tb):
        formatted = ''.join(traceback.format_exception(type(exc), exc, tb))
        self.exc = exc
        self.tb = '\n"""\n%s"""' % formatted

    def __reduce__(self):
        # Unpickling in the main process rebuilds the original exception
        # with the remote traceback attached as its __cause__.
        return rebuild_exc, (self.exc, self.tb)


def rebuild_exc(exc, tb):
    exc.__cause__ = RemoteTraceback(tb)
    return exc
40+
41+
42+
# Messages
43+
# ('writing_done', is_last, sample_idx, tuning, stats)
44+
# ('error', *exception_info)
45+
46+
# ('abort', reason)
47+
# ('write_next',)
48+
# ('start',)
49+
50+
51+
class _Process(multiprocessing.Process):
52+
"""Seperate process for each chain.
53+
54+
We communicate with the main process using a pipe,
55+
and send finished samples using shared memory.
56+
"""
57+
def __init__(self, name, msg_pipe, step_method, shared_point,
58+
draws, tune, seed):
59+
super(_Process, self).__init__(daemon=True, name=name)
60+
self._msg_pipe = msg_pipe
61+
self._step_method = step_method
62+
self._shared_point = shared_point
63+
self._seed = seed
64+
self._tt_seed = seed + 1
65+
self._draws = draws
66+
self._tune = tune
67+
68+
def run(self):
69+
try:
70+
# We do not create this in __init__, as pickling this
71+
# would destroy the shared memory.
72+
self._point = self._make_numpy_refs()
73+
self._start_loop()
74+
except KeyboardInterrupt:
75+
pass
76+
except BaseException as e:
77+
e = ExceptionWithTraceback(e, e.__traceback__)
78+
self._msg_pipe.send(('error', e))
79+
finally:
80+
self._msg_pipe.close()
81+
82+
def _make_numpy_refs(self):
83+
shape_dtypes = self._step_method.vars_shape_dtype
84+
point = {}
85+
for name, (shape, dtype) in shape_dtypes.items():
86+
array = self._shared_point[name]
87+
self._shared_point[name] = array
88+
point[name] = np.frombuffer(array, dtype).reshape(shape)
89+
return point
90+
91+
def _write_point(self, point):
92+
for name, vals in point.items():
93+
self._point[name][...] = vals
94+
95+
def _recv_msg(self):
96+
return self._msg_pipe.recv()
97+
98+
def _start_loop(self):
99+
np.random.seed(self._seed)
100+
theanof.set_tt_rng(self._tt_seed)
101+
102+
draw = 0
103+
tuning = True
104+
105+
msg = self._recv_msg()
106+
if msg[0] == 'abort':
107+
raise KeyboardInterrupt()
108+
if msg[0] != 'start':
109+
raise ValueError('Unexpected msg ' + msg[0])
110+
111+
while True:
112+
if draw < self._draws + self._tune:
113+
point, stats = self._compute_point()
114+
else:
115+
return
116+
117+
if draw == self._tune:
118+
self._step_method.stop_tuning()
119+
tuning = False
120+
121+
msg = self._recv_msg()
122+
if msg[0] == 'abort':
123+
raise KeyboardInterrupt()
124+
elif msg[0] == 'write_next':
125+
self._write_point(point)
126+
is_last = draw + 1 == self._draws + self._tune
127+
if is_last:
128+
warns = self._collect_warnings()
129+
else:
130+
warns = None
131+
self._msg_pipe.send(
132+
('writing_done', is_last, draw, tuning, stats, warns))
133+
draw += 1
134+
else:
135+
raise ValueError('Unknown message ' + msg[0])
136+
137+
def _compute_point(self):
138+
if self._step_method.generates_stats:
139+
point, stats = self._step_method.step(self._point)
140+
else:
141+
point = self._step_method.step(self._point)
142+
stats = None
143+
return point, stats
144+
145+
def _collect_warnings(self):
146+
if hasattr(self._step_method, 'warnings'):
147+
return self._step_method.warnings()
148+
else:
149+
return []
150+
151+
152+
class ProcessAdapter(object):
153+
"""Control a Chain process from the main thread."""
154+
def __init__(self, draws, tune, step_method, chain, seed, start):
155+
self.chain = chain
156+
process_name = "worker_chain_%s" % chain
157+
self._msg_pipe, remote_conn = multiprocessing.Pipe()
158+
159+
self._shared_point = {}
160+
self._point = {}
161+
for name, (shape, dtype) in step_method.vars_shape_dtype.items():
162+
size = 1
163+
for dim in shape:
164+
size *= int(dim)
165+
size *= dtype.itemsize
166+
if size != ctypes.c_size_t(size).value:
167+
raise ValueError('Variable %s is too large' % name)
168+
169+
array = multiprocessing.sharedctypes.RawArray('c', size)
170+
self._shared_point[name] = array
171+
array_np = np.frombuffer(array, dtype).reshape(shape)
172+
array_np[...] = start[name]
173+
self._point[name] = array_np
174+
175+
self._readable = True
176+
self._num_samples = 0
177+
178+
self._process = _Process(
179+
process_name, remote_conn, step_method, self._shared_point,
180+
draws, tune, seed)
181+
# We fork right away, so that the main process can start tqdm threads
182+
self._process.start()
183+
184+
@property
185+
def shared_point_view(self):
186+
"""May only be written to or read between a `recv_draw`
187+
call from the process and a `write_next` or `abort` call.
188+
"""
189+
if not self._readable:
190+
raise RuntimeError()
191+
return self._point
192+
193+
def start(self):
194+
self._msg_pipe.send(('start',))
195+
196+
def write_next(self):
197+
self._readable = False
198+
self._msg_pipe.send(('write_next',))
199+
200+
def abort(self):
201+
self._msg_pipe.send(('abort',))
202+
203+
def join(self, timeout=None):
204+
self._process.join(timeout)
205+
206+
def terminate(self):
207+
self._process.terminate()
208+
209+
@staticmethod
210+
def recv_draw(processes, timeout=3600):
211+
if not processes:
212+
raise ValueError('No processes.')
213+
pipes = [proc._msg_pipe for proc in processes]
214+
ready = multiprocessing.connection.wait(pipes)
215+
if not ready:
216+
raise multiprocessing.TimeoutError('No message from samplers.')
217+
idxs = {id(proc._msg_pipe): proc for proc in processes}
218+
proc = idxs[id(ready[0])]
219+
msg = ready[0].recv()
220+
221+
if msg[0] == 'error':
222+
old = msg[1]
223+
six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
224+
elif msg[0] == 'writing_done':
225+
proc._readable = True
226+
proc._num_samples += 1
227+
return (proc,) + msg[1:]
228+
else:
229+
raise ValueError('Sampler sent bad message.')
230+
231+
@staticmethod
232+
def terminate_all(processes, patience=2):
233+
for process in processes:
234+
try:
235+
process.abort()
236+
except EOFError:
237+
pass
238+
239+
start_time = time.time()
240+
try:
241+
for process in processes:
242+
timeout = time.time() + patience - start_time
243+
if timeout < 0:
244+
raise multiprocessing.TimeoutError()
245+
process.join(timeout)
246+
except multiprocessing.TimeoutError:
247+
logger.warn('Chain processes did not terminate as expected. '
248+
'Terminating forcefully...')
249+
for process in processes:
250+
process.terminate()
251+
for process in processes:
252+
process.join()
253+
254+
255+
# One finished draw as reported to the caller of ParallelSampler.
Draw = namedtuple(
    'Draw',
    ('chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings'),
)
259+
260+
261+
class ParallelSampler(object):
    """Run several chains in worker processes, at most `cores` at a time.

    Must be used as a context manager, so that worker processes are
    terminated (and the progress bar closed) on exit. Iterating the
    sampler yields one ``Draw`` per finished sample, across all chains.
    """
    def __init__(self, draws, tune, chains, cores, seeds, start_points,
                 step_method, start_chain_num=0, progressbar=True):
        if progressbar:
            import tqdm
            tqdm_ = tqdm.tqdm

        self._samplers = [
            ProcessAdapter(draws, tune, step_method,
                           chain + start_chain_num, seed, start)
            for chain, seed, start in zip(range(chains), seeds, start_points)
        ]

        self._inactive = self._samplers.copy()
        self._finished = []
        self._active = []
        self._max_active = cores

        self._in_context = False
        self._start_chain_num = start_chain_num

        self._progress = None
        if progressbar:
            self._progress = tqdm_(
                total=chains * (draws + tune), unit='draws',
                desc='Sampling %s chains' % chains)

    def _make_active(self):
        # Promote pending chains until `cores` workers are running.
        while self._inactive and len(self._active) < self._max_active:
            proc = self._inactive.pop(0)
            proc.start()
            proc.write_next()
            self._active.append(proc)

    def __iter__(self):
        if not self._in_context:
            raise ValueError('Use ParallelSampler as context manager.')
        self._make_active()

        while self._active:
            # Avoid shadowing: keep the message tuple and the draw index
            # in separate names.
            msg = ProcessAdapter.recv_draw(self._active)
            proc, is_last, draw_idx, tuning, stats, warns = msg
            if self._progress is not None:
                self._progress.update()

            if is_last:
                proc.join()
                self._active.remove(proc)
                self._finished.append(proc)
                self._make_active()

            # We could also yield proc.shared_point_view directly,
            # and only call proc.write_next() after the yield returns.
            # This seems to be faster overall though, as the worker
            # loses less time waiting.
            point = {name: val.copy()
                     for name, val in proc.shared_point_view.items()}

            # Already called for new proc in _make_active
            if not is_last:
                proc.write_next()

            yield Draw(proc.chain, is_last, draw_idx, tuning, stats, point,
                       warns)

    def __enter__(self):
        self._in_context = True
        return self

    def __exit__(self, *args):
        ProcessAdapter.terminate_all(self._samplers)
        if self._progress is not None:
            self._progress.close()

0 commit comments

Comments
 (0)