 import logging
 from collections import namedtuple
 import traceback
+from pymc3.exceptions import SamplingError
 
 import six
 import numpy as np
 
 from . import theanof
 
-logger = logging.getLogger('pymc3')
+logger = logging.getLogger("pymc3")
+
+
+class ParallelSamplingError(Exception):
+    def __init__(self, message, chain, warnings=None):
+        super(ParallelSamplingError, self).__init__(message)
+        if warnings is None:
+            warnings = []
+        self._chain = chain
+        self._warnings = warnings
 
 
 # Taken from https://hg.python.org/cpython/rev/c4f92b597074
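The new ParallelSamplingError lets the main process report which chain failed and carry along the warnings the worker collected from its step method. A minimal usage sketch (not part of this diff; it assumes the exception propagates out of pm.sample unchanged and peeks at the private _chain/_warnings attributes defined above):

import pymc3 as pm
from pymc3.parallel_sampling import ParallelSamplingError

with pm.Model():
    pm.Normal("x", mu=0, sd=1)
    try:
        trace = pm.sample(cores=2)
    except ParallelSamplingError as err:
        # The failing chain index and the worker's warnings travel with
        # the exception instead of being lost with the child process.
        print(err, err._chain, err._warnings)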
@@ -26,7 +36,7 @@ def __str__(self):
 class ExceptionWithTraceback:
     def __init__(self, exc, tb):
         tb = traceback.format_exception(type(exc), exc, tb)
-        tb = ''.join(tb)
+        tb = "".join(tb)
         self.exc = exc
         self.tb = '\n"""\n%s"""' % tb
 
@@ -40,8 +50,8 @@ def rebuild_exc(exc, tb):
 
 
 # Messages
-# ('writing_done', is_last, sample_idx, tuning, stats)
-# ('error', *exception_info)
+# ('writing_done', is_last, sample_idx, tuning, stats, warns)
+# ('error', warnings, *exception_info)
 
 # ('abort', reason)
 # ('write_next',)
@@ -50,12 +60,11 @@ def rebuild_exc(exc, tb):
 
 
 class _Process(multiprocessing.Process):
     """Seperate process for each chain.
-
     We communicate with the main process using a pipe,
     and send finished samples using shared memory.
     """
-    def __init__(self, name, msg_pipe, step_method, shared_point,
-                 draws, tune, seed):
+
+    def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
         super(_Process, self).__init__(daemon=True, name=name)
         self._msg_pipe = msg_pipe
         self._step_method = step_method
@@ -75,7 +84,7 @@ def run(self):
             pass
         except BaseException as e:
             e = ExceptionWithTraceback(e, e.__traceback__)
-            self._msg_pipe.send(('error', e))
+            self._msg_pipe.send(("error", None, e))
         finally:
             self._msg_pipe.close()
 
@@ -103,14 +112,19 @@ def _start_loop(self):
         tuning = True
 
         msg = self._recv_msg()
-        if msg[0] == 'abort':
+        if msg[0] == "abort":
             raise KeyboardInterrupt()
-        if msg[0] != 'start':
-            raise ValueError('Unexpected msg ' + msg[0])
+        if msg[0] != "start":
+            raise ValueError("Unexpected msg " + msg[0])
 
         while True:
             if draw < self._draws + self._tune:
-                point, stats = self._compute_point()
+                try:
+                    point, stats = self._compute_point()
+                except SamplingError as e:
+                    warns = self._collect_warnings()
+                    e = ExceptionWithTraceback(e, e.__traceback__)
+                    self._msg_pipe.send(("error", warns, e))
             else:
                 return
 
@@ -119,20 +133,21 @@ def _start_loop(self):
                 tuning = False
 
             msg = self._recv_msg()
-            if msg[0] == 'abort':
+            if msg[0] == "abort":
                 raise KeyboardInterrupt()
-            elif msg[0] == 'write_next':
+            elif msg[0] == "write_next":
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
                 if is_last:
                     warns = self._collect_warnings()
                 else:
                     warns = None
                 self._msg_pipe.send(
-                    ('writing_done', is_last, draw, tuning, stats, warns))
+                    ("writing_done", is_last, draw, tuning, stats, warns)
+                )
                 draw += 1
             else:
-                raise ValueError('Unknown message ' + msg[0])
+                raise ValueError("Unknown message " + msg[0])
 
     def _compute_point(self):
         if self._step_method.generates_stats:
@@ -143,14 +158,15 @@ def _compute_point(self):
         return point, stats
 
     def _collect_warnings(self):
-        if hasattr(self._step_method, 'warnings'):
+        if hasattr(self._step_method, "warnings"):
             return self._step_method.warnings()
         else:
             return []
 
 
 class ProcessAdapter(object):
     """Control a Chain process from the main thread."""
+
     def __init__(self, draws, tune, step_method, chain, seed, start):
         self.chain = chain
         process_name = "worker_chain_%s" % chain
@@ -164,9 +180,9 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
                 size *= int(dim)
             size *= dtype.itemsize
             if size != ctypes.c_size_t(size).value:
-                raise ValueError('Variable %s is too large' % name)
+                raise ValueError("Variable %s is too large" % name)
 
-            array = multiprocessing.sharedctypes.RawArray('c', size)
+            array = multiprocessing.sharedctypes.RawArray("c", size)
             self._shared_point[name] = array
             array_np = np.frombuffer(array, dtype).reshape(shape)
             array_np[...] = start[name]
@@ -176,8 +192,14 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
         self._num_samples = 0
 
         self._process = _Process(
-            process_name, remote_conn, step_method, self._shared_point,
-            draws, tune, seed)
+            process_name,
+            remote_conn,
+            step_method,
+            self._shared_point,
+            draws,
+            tune,
+            seed,
+        )
         # We fork right away, so that the main process can start tqdm threads
         self._process.start()
 
@@ -191,14 +213,14 @@ def shared_point_view(self):
         return self._point
 
     def start(self):
-        self._msg_pipe.send(('start',))
+        self._msg_pipe.send(("start",))
 
     def write_next(self):
         self._readable = False
-        self._msg_pipe.send(('write_next',))
+        self._msg_pipe.send(("write_next",))
 
     def abort(self):
-        self._msg_pipe.send(('abort',))
+        self._msg_pipe.send(("abort",))
 
     def join(self, timeout=None):
         self._process.join(timeout)
@@ -209,24 +231,28 @@ def terminate(self):
     @staticmethod
     def recv_draw(processes, timeout=3600):
         if not processes:
-            raise ValueError('No processes.')
+            raise ValueError("No processes.")
         pipes = [proc._msg_pipe for proc in processes]
         ready = multiprocessing.connection.wait(pipes)
         if not ready:
-            raise multiprocessing.TimeoutError('No message from samplers.')
+            raise multiprocessing.TimeoutError("No message from samplers.")
         idxs = {id(proc._msg_pipe): proc for proc in processes}
         proc = idxs[id(ready[0])]
         msg = ready[0].recv()
 
-        if msg[0] == 'error':
-            old = msg[1]
-            six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
-        elif msg[0] == 'writing_done':
+        if msg[0] == "error":
+            warns, old_error = msg[1:]
+            if warns is not None:
+                error = ParallelSamplingError(str(old_error), proc.chain, warns)
+            else:
+                error = RuntimeError("Chain %s failed." % proc.chain)
+            six.raise_from(error, old_error)
+        elif msg[0] == "writing_done":
             proc._readable = True
             proc._num_samples += 1
             return (proc,) + msg[1:]
         else:
-            raise ValueError('Sampler sent bad message.')
+            raise ValueError("Sampler sent bad message.")
 
     @staticmethod
     def terminate_all(processes, patience=2):
@@ -244,34 +270,46 @@ def terminate_all(processes, patience=2):
                     raise multiprocessing.TimeoutError()
                 process.join(timeout)
         except multiprocessing.TimeoutError:
-            logger.warn('Chain processes did not terminate as expected. '
-                        'Terminating forcefully...')
+            logger.warn(
+                "Chain processes did not terminate as expected. "
+                "Terminating forcefully..."
+            )
             for process in processes:
                 process.terminate()
             for process in processes:
                 process.join()
 
 
 Draw = namedtuple(
-    'Draw',
-    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
+    "Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"]
 )
 
 
 class ParallelSampler(object):
-    def __init__(self, draws, tune, chains, cores, seeds, start_points,
-                 step_method, start_chain_num=0, progressbar=True):
+    def __init__(
+        self,
+        draws,
+        tune,
+        chains,
+        cores,
+        seeds,
+        start_points,
+        step_method,
+        start_chain_num=0,
+        progressbar=True,
+    ):
         if progressbar:
             import tqdm
+
             tqdm_ = tqdm.tqdm
 
         if any(len(arg) != chains for arg in [seeds, start_points]):
-            raise ValueError(
-                'Number of seeds and start_points must be %s.' % chains)
+            raise ValueError("Number of seeds and start_points must be %s." % chains)
 
         self._samplers = [
-            ProcessAdapter(draws, tune, step_method,
-                           chain + start_chain_num, seed, start)
+            ProcessAdapter(
+                draws, tune, step_method, chain + start_chain_num, seed, start
+            )
             for chain, seed, start in zip(range(chains), seeds, start_points)
         ]
 
@@ -286,8 +324,10 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
         self._progress = None
         if progressbar:
             self._progress = tqdm_(
-                total=chains * (draws + tune), unit='draws',
-                desc='Sampling %s chains' % chains)
+                total=chains * (draws + tune),
+                unit="draws",
+                desc="Sampling %s chains" % chains,
+            )
 
     def _make_active(self):
         while self._inactive and len(self._active) < self._max_active:
@@ -298,7 +338,7 @@ def _make_active(self):
 
     def __iter__(self):
         if not self._in_context:
-            raise ValueError('Use ParallelSampler as context manager.')
+            raise ValueError("Use ParallelSampler as context manager.")
         self._make_active()
 
         while self._active:
@@ -317,8 +357,7 @@ def __iter__(self):
             # and only call proc.write_next() after the yield returns.
             # This seems to be faster overally though, as the worker
             # loses less time waiting.
-            point = {name: val.copy()
-                     for name, val in proc.shared_point_view.items()}
+            point = {name: val.copy() for name, val in proc.shared_point_view.items()}
 
             # Already called for new proc in _make_active
             if not is_last:
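For reference, after this change the worker sends ("error", warns, exc) when a SamplingError interrupts a draw and ("error", None, exc) for any other failure, while successful writes reply with ("writing_done", is_last, draw, tuning, stats, warns). A small sketch of the receiving side's decision, mirroring the recv_draw logic above (make_error is a hypothetical helper, not part of the diff):

def make_error(msg, chain):
    # msg == ("error", warns, old_error); warns is a list of step-method
    # warnings for SamplingError failures and None otherwise.
    _, warns, old_error = msg
    if warns is not None:
        # Known sampling failure: keep the chain number and the warnings.
        return ParallelSamplingError(str(old_error), chain, warns)
    # Unexpected failure: fall back to a plain RuntimeError.
    return RuntimeError("Chain %s failed." % chain)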