|
6 | 6 | from collections import namedtuple
|
7 | 7 | import traceback
|
8 | 8 | from pymc3.exceptions import SamplingError
|
| 9 | +import errno |
9 | 10 |
|
10 | 11 | import numpy as np
|
11 | 12 |
|
|
14 | 15 | logger = logging.getLogger("pymc3")
|
15 | 16 |
|
16 | 17 |
|
| 18 | +def _get_broken_pipe_exception(): |
| 19 | + import sys |
| 20 | + if sys.platform == 'win32': |
| 21 | + return RuntimeError("The communication pipe between the main process " |
| 22 | + "and its spawned children is broken.\n" |
| 23 | + "In Windows OS, this usually means that the child " |
| 24 | + "process raised an exception while it was being " |
| 25 | + "spawned, before it was setup to communicate to " |
| 26 | + "the main process.\n" |
| 27 | + "The exceptions raised by the child process while " |
| 28 | + "spawning cannot be caught or handled from the " |
| 29 | + "main process, and when running from an IPython or " |
| 30 | + "jupyter notebook interactive kernel, the child's " |
| 31 | + "exception and traceback appears to be lost.\n" |
| 32 | + "A known way to see the child's error, and try to " |
| 33 | + "fix or handle it, is to run the problematic code " |
| 34 | + "as a batch script from a system's Command Prompt. " |
| 35 | + "The child's exception will be printed to the " |
| 36 | + "Command Promt's stderr, and it should be visible " |
| 37 | + "above this error and traceback.\n" |
| 38 | + "Note that if running a jupyter notebook that was " |
| 39 | + "invoked from a Command Prompt, the child's " |
| 40 | + "exception should have been printed to the Command " |
| 41 | + "Prompt on which the notebook is running.") |
| 42 | + else: |
| 43 | + return None |
| 44 | + |
| 45 | + |
17 | 46 | class ParallelSamplingError(Exception):
|
18 | 47 | def __init__(self, message, chain, warnings=None):
|
19 | 48 | super().__init__(message)
|
@@ -83,10 +112,19 @@ def run(self):
|
83 | 112 | pass
|
84 | 113 | except BaseException as e:
|
85 | 114 | e = ExceptionWithTraceback(e, e.__traceback__)
|
| 115 | + # Send is not blocking so we have to force a wait for the abort |
| 116 | + # message |
86 | 117 | self._msg_pipe.send(("error", None, e))
|
| 118 | + self._wait_for_abortion() |
87 | 119 | finally:
|
88 | 120 | self._msg_pipe.close()
|
89 | 121 |
|
| 122 | + def _wait_for_abortion(self): |
| 123 | + while True: |
| 124 | + msg = self._recv_msg() |
| 125 | + if msg[0] == "abort": |
| 126 | + break |
| 127 | + |
90 | 128 | def _make_numpy_refs(self):
|
91 | 129 | shape_dtypes = self._step_method.vars_shape_dtype
|
92 | 130 | point = {}
|
@@ -200,7 +238,18 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
|
200 | 238 | seed,
|
201 | 239 | )
|
202 | 240 | # We fork right away, so that the main process can start tqdm threads
|
203 |
| - self._process.start() |
| 241 | + try: |
| 242 | + self._process.start() |
| 243 | + except IOError as e: |
| 244 | + # Something may have gone wrong during the fork / spawn |
| 245 | + if e.errno == errno.EPIPE: |
| 246 | + exc = _get_broken_pipe_exception() |
| 247 | + if exc is not None: |
| 248 | + # Sleep a little to give the child process time to flush |
| 249 | + # all its error message |
| 250 | + time.sleep(0.2) |
| 251 | + raise exc |
| 252 | + raise |
204 | 253 |
|
205 | 254 | @property
|
206 | 255 | def shared_point_view(self):
|
|
0 commit comments