Skip to content

Commit f0b5d28

Browse files
committed
Fix tests for extra divergence info
1 parent 44215ca commit f0b5d28

File tree

4 files changed

+92
-66
lines changed

4 files changed

+92
-66
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# Release Notes
22

33
## PyMC3 3.9.x (on deck)
4+
5+
### Maintenance
6+
- Fix an error on Windows and Mac where error message from unpickling models did not show up in the notebook, or where sampling froze when a worker process crashed (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).
7+
48
### Documentation
59
- Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))
610

711
### New features
12+
- Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing context is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
813
- Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [#3986](https://github.com/pymc-devs/pymc3/pull/3986)).
914

1015
## PyMC3 3.9.2 (24 June 2020)

pymc3/parallel_sampling.py

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from collections import namedtuple
2222
import traceback
2323
from pymc3.exceptions import SamplingError
24-
import errno
2524

2625
import numpy as np
2726
from fastprogress.fastprogress import progress_bar
@@ -31,37 +30,6 @@
3130
logger = logging.getLogger("pymc3")
3231

3332

34-
def _get_broken_pipe_exception():
35-
import sys
36-
37-
if sys.platform == "win32":
38-
return RuntimeError(
39-
"The communication pipe between the main process "
40-
"and its spawned children is broken.\n"
41-
"In Windows OS, this usually means that the child "
42-
"process raised an exception while it was being "
43-
"spawned, before it was setup to communicate to "
44-
"the main process.\n"
45-
"The exceptions raised by the child process while "
46-
"spawning cannot be caught or handled from the "
47-
"main process, and when running from an IPython or "
48-
"jupyter notebook interactive kernel, the child's "
49-
"exception and traceback appears to be lost.\n"
50-
"A known way to see the child's error, and try to "
51-
"fix or handle it, is to run the problematic code "
52-
"as a batch script from a system's Command Prompt. "
53-
"The child's exception will be printed to the "
54-
"Command Promt's stderr, and it should be visible "
55-
"above this error and traceback.\n"
56-
"Note that if running a jupyter notebook that was "
57-
"invoked from a Command Prompt, the child's "
58-
"exception should have been printed to the Command "
59-
"Prompt on which the notebook is running."
60-
)
61-
else:
62-
return None
63-
64-
6533
class ParallelSamplingError(Exception):
6634
def __init__(self, message, chain, warnings=None):
6735
super().__init__(message)
@@ -133,18 +101,37 @@ def __init__(
133101
self._tune = tune
134102
self._pickle_backend = pickle_backend
135103

104+
def _unpickle_step_method(self):
105+
unpickle_error = (
106+
"The model could not be unpickled. This is required for sampling "
107+
"with more than one core and multiprocessing context spawn "
108+
"or forkserver."
109+
)
110+
if self._step_method_is_pickled:
111+
if self._pickle_backend == 'pickle':
112+
try:
113+
self._step_method = pickle.loads(self._step_method)
114+
except Exception:
115+
raise ValueError(unpickle_error)
116+
elif self._pickle_backend == 'dill':
117+
try:
118+
import dill
119+
except ImportError:
120+
raise ValueError(
121+
"dill must be installed for pickle_backend='dill'."
122+
)
123+
try:
124+
self._step_method = dill.loads(self._step_method)
125+
except Exception:
126+
raise ValueError(unpickle_error)
127+
else:
128+
raise ValueError("Unknown pickle backend")
129+
136130
def run(self):
137131
try:
138132
# We do not create this in __init__, as pickling this
139133
# would destroy the shared memory.
140-
if self._step_method_is_pickled:
141-
if self._pickle_backend == 'pickle':
142-
self._step_method = pickle.loads(self._step_method)
143-
elif self._pickle_backend == 'dill':
144-
import dill
145-
self._step_method = dill.loads(self._step_method)
146-
else:
147-
raise ValueError("Unknown pickle backend")
134+
self._unpickle_step_method()
148135
self._point = self._make_numpy_refs()
149136
self._start_loop()
150137
except KeyboardInterrupt:
@@ -289,7 +276,7 @@ def __init__(
289276

290277
self._process = mp_ctx.Process(
291278
daemon=True,
292-
name=name,
279+
name=process_name,
293280
target=_run_process,
294281
args=(
295282
process_name,
@@ -303,21 +290,10 @@ def __init__(
303290
pickle_backend,
304291
)
305292
)
306-
try:
307-
self._process.start()
308-
# Close the remote pipe, so that we get notified if the other
309-
# end is closed.
310-
remote_conn.close()
311-
except IOError as e:
312-
# Something may have gone wrong during the fork / spawn
313-
if e.errno == errno.EPIPE:
314-
exc = _get_broken_pipe_exception()
315-
if exc is not None:
316-
# Sleep a little to give the child process time to flush
317-
# all its error message
318-
time.sleep(0.2)
319-
raise exc
320-
raise
293+
self._process.start()
294+
# Close the remote pipe, so that we get notified if the other
295+
# end is closed.
296+
remote_conn.close()
321297

322298
@property
323299
def shared_point_view(self):
@@ -451,7 +427,12 @@ def __init__(
451427
if pickle_backend == 'pickle':
452428
step_method_pickled = pickle.dumps(step_method, protocol=-1)
453429
elif pickle_backend == 'dill':
454-
import dill
430+
try:
431+
import dill
432+
except ImportError:
433+
raise ValueError(
434+
"dill must be installed for pickle_backend='dill'."
435+
)
455436
step_method_pickled = dill.dumps(step_method, protocol=-1)
456437

457438
self._samplers = [

pymc3/sampling.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,13 @@ def sample(
338338
Defaults to `False`, but we'll switch to `True` in an upcoming release.
339339
idata_kwargs : dict, optional
340340
Keyword arguments for `arviz.from_pymc3`
341-
mp_ctx : str
342-
The name of a multiprocessing context. One of `fork`, `spawn` or `forkserver`.
343-
See multiprocessing documentation for details.
341+
mp_ctx : multiprocessing.context.BaseContext
342+
A multiprocessing context for parallel sampling. See multiprocessing
343+
documentation for details.
344+
pickle_backend : str
345+
One of `'pickle'` or `'dill'`. The library used to pickle models
346+
in parallel sampling if the multiprocessing context is not of type
347+
`fork`.
344348
345349
Returns
346350
-------
@@ -508,8 +512,10 @@ def sample(
508512
"cores": cores,
509513
"callback": callback,
510514
"discard_tuned_samples": discard_tuned_samples,
511-
"mp_ctx": mp_ctx,
515+
}
516+
parallel_args = {
512517
"pickle_backend": pickle_backend,
518+
"mp_ctx": mp_ctx,
513519
}
514520

515521
sample_args.update(kwargs)
@@ -527,7 +533,7 @@ def sample(
527533
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
528534
_print_step_hierarchy(step)
529535
try:
530-
trace = _mp_sample(**sample_args)
536+
trace = _mp_sample(**sample_args, **parallel_args)
531537
except pickle.PickleError:
532538
_log.warning("Could not pickle model, sampling singlethreaded.")
533539
_log.debug("Pickling error:", exec_info=True)

pymc3/tests/test_parallel_sampling.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,41 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import multiprocessing
1415

16+
import pytest
1517
import pymc3.parallel_sampling as ps
1618
import pymc3 as pm
1719

1820

21+
def test_context():
22+
with pm.Model():
23+
pm.Normal('x')
24+
ctx = multiprocessing.get_context('spawn')
25+
pm.sample(tune=2, draws=2, chains=2, cores=2, mp_ctx=ctx)
26+
27+
28+
class NoUnpickle:
29+
def __getstate__(self):
30+
return self.__dict__.copy()
31+
32+
def __setstate__(self, state):
33+
raise AttributeError("This fails")
34+
35+
36+
def test_bad_unpickle():
37+
with pm.Model() as model:
38+
pm.Normal('x')
39+
40+
with model:
41+
step = pm.NUTS()
42+
step.no_unpickle = NoUnpickle()
43+
with pytest.raises(Exception) as exc_info:
44+
pm.sample(tune=2, draws=2, mp_ctx='spawn', step=step,
45+
cores=2, chains=2, compute_convergence_checks=False)
46+
assert 'could not be unpickled' in str(exc_info.getrepr(style='short'))
47+
48+
1949
def test_abort():
2050
with pm.Model() as model:
2151
a = pm.Normal('a', shape=1)
@@ -25,8 +55,10 @@ def test_abort():
2555

2656
step = pm.CompoundStep([step1, step2])
2757

28-
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
29-
start={'a': 1., 'b_log__': 2.})
58+
ctx = multiprocessing.get_context()
59+
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
60+
start={'a': 1., 'b_log__': 2.},
61+
step_method_pickled=None, pickle_backend='pickle')
3062
proc.start()
3163
proc.write_next()
3264
proc.abort()
@@ -42,8 +74,10 @@ def test_explicit_sample():
4274

4375
step = pm.CompoundStep([step1, step2])
4476

45-
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
46-
start={'a': 1., 'b_log__': 2.})
77+
ctx = multiprocessing.get_context()
78+
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
79+
start={'a': 1., 'b_log__': 2.},
80+
step_method_pickled=None, pickle_backend='pickle')
4781
proc.start()
4882
while True:
4983
proc.write_next()

0 commit comments

Comments
 (0)