Skip to content

Commit 85658d7

Browse files
committed
Rewrite of multiprocessing code
1 parent f361844 commit 85658d7

File tree

5 files changed

+434
-18
lines changed

5 files changed

+434
-18
lines changed

pymc3/parallel_sampling.py

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
import multiprocessing
2+
import multiprocessing.sharedctypes
3+
import sys
4+
import ctypes
5+
import time
6+
import logging
7+
from collections import namedtuple
8+
9+
import six
10+
import tqdm
11+
import numpy as np
12+
13+
from . import theanof
14+
15+
logger = logging.getLogger('pymc3')
16+
17+
# Messages
18+
# ('writing_done', is_last, sample_idx, tuning, stats)
19+
# ('error', *exception_info)
20+
21+
# ('abort', reason)
22+
# ('write_next',)
23+
# ('start',)
24+
25+
26+
class _Process(multiprocessing.Process):
    """Separate process for each chain.

    We communicate with the main process using a pipe,
    and send finished samples using shared memory.
    """
    def __init__(self, msg_pipe, step_method, shared_point, draws, tune, seed):
        super(_Process, self).__init__(daemon=True)
        self._msg_pipe = msg_pipe
        self._step_method = step_method
        self._shared_point = shared_point
        self._seed = seed
        self._tt_seed = seed + 1
        self._draws = draws
        self._tune = tune

    def run(self):
        try:
            # We do not create this in __init__, as pickling this
            # would destroy the shared memory.
            self._point = self._make_numpy_refs()
            self._start_loop()
        except KeyboardInterrupt:
            # Raised by _start_loop on an 'abort' message: shut down quietly.
            pass
        except BaseException:
            exc_info = sys.exc_info()
            # Send only (type, value); traceback objects do not pickle.
            self._msg_pipe.send(('error', exc_info[:2]))
        finally:
            self._msg_pipe.close()

    def _make_numpy_refs(self):
        """Return numpy views onto the shared buffers, one per variable."""
        shape_dtypes = self._step_method.vars_shape_dtype
        point = {}
        for name, (shape, dtype) in shape_dtypes.items():
            array = self._shared_point[name]
            # (A redundant write-back of `array` into self._shared_point
            # was removed here; it re-assigned the value just read.)
            point[name] = np.frombuffer(array, dtype).reshape(shape)
        return point

    def _write_point(self, point):
        # Copy the sampled values into shared memory for the main process.
        for name, vals in point.items():
            self._point[name][...] = vals

    def _recv_msg(self):
        return self._msg_pipe.recv()

    def _start_loop(self):
        """Main worker loop: step the sampler once per 'write_next' message."""
        np.random.seed(self._seed)
        theanof.set_tt_rng(self._tt_seed)

        draw = 0
        tuning = True

        msg = self._recv_msg()
        if msg[0] == 'abort':
            raise KeyboardInterrupt()
        if msg[0] != 'start':
            raise ValueError('Unexpected msg ' + msg[0])

        while True:
            if draw < self._draws + self._tune:
                point, stats = self._compute_point()
            else:
                return

            if draw == self._tune:
                self._step_method.stop_tuning()
                tuning = False

            msg = self._recv_msg()
            if msg[0] == 'abort':
                raise KeyboardInterrupt()
            elif msg[0] == 'write_next':
                self._write_point(point)
                is_last = draw + 1 == self._draws + self._tune
                self._msg_pipe.send(
                    ('writing_done', is_last, draw, tuning, stats))
                draw += 1
            else:
                raise ValueError('Unknown message ' + msg[0])

    def _compute_point(self):
        """Advance the step method by one draw; stats is None if unsupported."""
        if self._step_method.generates_stats:
            point, stats = self._step_method.step(self._point)
        else:
            point = self._step_method.step(self._point)
            stats = None
        return point, stats
114+
115+
116+
class ProcessAdapter(object):
    """Control a Chain process from the main thread."""
    def __init__(self, draws, tune, step_method, chain, seed, start):
        self.chain = chain
        self._msg_pipe, remote_conn = multiprocessing.Pipe()

        # One shared raw buffer per model variable, plus a numpy view of
        # each so the main process can read finished samples in place.
        self._shared_point = {}
        self._point = {}
        for name, (shape, dtype) in step_method.vars_shape_dtype.items():
            size = 1
            for dim in shape:
                size *= int(dim)
            size *= dtype.itemsize
            # Guard against overflow of the C size_t used by RawArray.
            if size != ctypes.c_size_t(size).value:
                raise ValueError('Variable %s is too large' % name)

            array = multiprocessing.sharedctypes.RawArray('c', size)
            self._shared_point[name] = array
            array_np = np.frombuffer(array, dtype).reshape(shape)
            array_np[...] = start[name]
            self._point[name] = array_np

        self._readable = True
        self._num_samples = 0

        self._process = _Process(
            remote_conn, step_method, self._shared_point, draws, tune, seed)
        # We fork right away, so that the main process can start tqdm threads
        self._process.start()

    @property
    def shared_point_view(self):
        """May only be written to or read between a `recv_draw`
        call from the process and a `write_next` or `abort` call.
        """
        if not self._readable:
            raise RuntimeError()
        return self._point

    def start(self):
        self._msg_pipe.send(('start',))

    def write_next(self):
        # The worker will now overwrite the shared buffers.
        self._readable = False
        self._msg_pipe.send(('write_next',))

    def abort(self):
        self._msg_pipe.send(('abort',))

    def join(self, timeout=None):
        self._process.join(timeout)

    def terminate(self):
        self._process.terminate()

    @staticmethod
    def recv_draw(processes, timeout=3600):
        """Wait for any worker to finish a draw.

        Returns ``(proc, is_last, draw_idx, tuning, stats)``.
        Raises ``multiprocessing.TimeoutError`` if no worker reports
        within `timeout` seconds, and re-raises worker errors.
        """
        if not processes:
            raise ValueError('No processes.')
        pipes = [proc._msg_pipe for proc in processes]
        # Pass the timeout through -- previously it was accepted but never
        # used, so this call could block forever and the check below was dead.
        ready = multiprocessing.connection.wait(pipes, timeout)
        if not ready:
            raise multiprocessing.TimeoutError('No message from samplers.')
        idxs = {id(proc._msg_pipe): proc for proc in processes}
        proc = idxs[id(ready[0])]
        msg = ready[0].recv()

        if msg[0] == 'error':
            # msg[1] is (exc_type, exc_value); the traceback was not sent.
            old = msg[1][1]
            six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
        elif msg[0] == 'writing_done':
            proc._readable = True
            proc._num_samples += 1
            return (proc, *msg[1:])
        else:
            raise ValueError('Sampler sent bad message.')

    @staticmethod
    def terminate_all(processes, patience=2):
        """Ask workers to stop; force-terminate after `patience` seconds."""
        for process in processes:
            try:
                process.abort()
            except EOFError:
                pass

        start_time = time.time()
        try:
            for process in processes:
                timeout = start_time + patience - time.time()
                if timeout < 0:
                    raise multiprocessing.TimeoutError()
                process.join(timeout)
        except multiprocessing.TimeoutError:
            logger.warning('Chain processes did not terminate as expected. '
                           'Terminating forcefully...')
            for process in processes:
                process.terminate()
            for process in processes:
                process.join()
215+
216+
217+
# One finished sample as delivered to the main process: `point` is the
# shared-memory view held by the chain's ProcessAdapter.
Draw = namedtuple(
    'Draw',
    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point']
)
221+
222+
223+
class ParallelSampler(object):
    """Drive several chain processes, keeping at most `cores` active,
    and yield their draws from the main thread."""

    def __init__(self, draws, tune, chains, cores, seeds, start_points,
                 step_method, start_chain_num=0, progressbar=True):
        self._samplers = [
            ProcessAdapter(draws, tune, step_method,
                           num + start_chain_num, seed, start)
            for num, seed, start in zip(range(chains), seeds, start_points)
        ]

        self._inactive = list(self._samplers)
        self._finished = []
        self._active = []
        self._max_active = cores

        self._in_context = False
        self._start_chain_num = start_chain_num

        self._global_progress = None
        self._progress = None
        if progressbar:
            self._global_progress = tqdm.tqdm(
                total=chains, unit='chains', position=1)
            self._progress = [
                tqdm.tqdm(
                    desc=' Chain %i' % (num + start_chain_num),
                    unit='draws',
                    position=num + 2,
                    total=draws + tune)
                for num in range(chains)
            ]

    def _make_active(self):
        # Promote waiting samplers until `cores` of them are running.
        while self._inactive and len(self._active) < self._max_active:
            sampler = self._inactive.pop()
            sampler.start()
            sampler.write_next()
            self._active.append(sampler)

    def __iter__(self):
        if not self._in_context:
            raise ValueError('Use ParallelSampler as context manager.')
        self._make_active()

        while self._active:
            proc, is_last, draw_idx, tuning, stats = (
                ProcessAdapter.recv_draw(self._active))
            if self._progress is not None:
                bar = self._progress[proc.chain - self._start_chain_num]
                bar.update()

            if is_last:
                proc.join()
                self._active.remove(proc)
                self._finished.append(proc)
                self._make_active()
                if self._global_progress is not None:
                    self._global_progress.update()

            yield Draw(
                proc.chain, is_last, draw_idx, tuning,
                stats, proc.shared_point_view
            )

            # Freshly activated processes already got their first
            # `write_next` inside `_make_active`.
            if not is_last:
                proc.write_next()

    def __enter__(self):
        self._in_context = True
        return self

    def __exit__(self, *args):
        ProcessAdapter.terminate_all(self._samplers)
        if self._progress is not None:
            self._global_progress.close()
            for bar in self._progress:
                bar.close()

pymc3/sampling.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -966,35 +966,64 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
966966

967967

968968
def _mp_sample(**kwargs):
    """Sample `chains` chains in parallel and return a MultiTrace.

    On Python 3 this uses the new pymc3.parallel_sampling machinery;
    on Python 2 it falls back to the joblib-based `_sample` path.
    """
    import sys

    cores = kwargs.pop('cores')
    chain = kwargs.pop('chain')
    rseed = kwargs.pop('random_seed')
    start = kwargs.pop('start')
    chains = kwargs.pop('chains')
    use_mmap = kwargs.pop('use_mmap')

    if sys.version_info.major >= 3:
        import pymc3.parallel_sampling as ps

        # Pop these only on the py3 path: the joblib fallback below pops
        # 'progressbar' itself and forwards draws/tune/step to _sample via
        # **kwargs, so popping them unconditionally would raise KeyError
        # and starve _sample of required arguments.
        draws = kwargs.pop('draws')
        tune = kwargs.pop('tune')
        step = kwargs.pop('step')
        progressbar = kwargs.pop('progressbar')

        model = modelcontext(kwargs.pop('model', None))
        trace = kwargs.pop('trace', None)
        traces = []
        for idx in range(chain, chain + chains):
            strace = _choose_backend(trace, idx, model=model)
            # TODO what is this for?
            update_start_vals(start[idx - chain], model.test_point, model)
            # NOTE(review): `idx + chain` looks suspicious -- `idx` already
            # includes the `chain` offset, so the backend receives a doubled
            # chain number. Confirm the intended chain numbering.
            if step.generates_stats and strace.supports_sampler_stats:
                strace.setup(draws + tune, idx + chain, step.stats_dtypes)
            else:
                strace.setup(draws + tune, idx + chain)
            traces.append(strace)

        sampler = ps.ParallelSampler(
            draws, tune, chains, cores, rseed, start, step, chain, progressbar)
        with sampler:
            for draw in sampler:
                trace = traces[draw.chain - chain]
                if trace.supports_sampler_stats and draw.stats is not None:
                    trace.record(draw.point, draw.stats)
                else:
                    trace.record(draw.point)
                if draw.is_last:
                    trace.close()
        return MultiTrace(traces)

    else:
        # Legacy joblib path: only the first chain shows a progress bar.
        chain_nums = list(range(chain, chain + chains))
        pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
        jobs = (delayed(_sample)(*args, **kwargs)
                for args in zip(chain_nums, pbars, rseed, start))
        if use_mmap:
            traces = Parallel(n_jobs=cores)(jobs)
        else:
            traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
        return MultiTrace(traces)
9871021

9881022

9891023
def stop_tuning(step):
    """Stop tuning on *step* and hand the step method back to the caller.

    Step methods implement ``stop_tuning`` themselves, so this simply
    delegates (compound steps are expected to propagate to sub-methods).
    """
    step.stop_tuning()
    return step
9991028

10001029

0 commit comments

Comments
 (0)