Skip to content

Commit 6ccc11b

Browse files
authored
Improvements in error messages and logging for bad initial energy. (#3238)
* Improvements in error messages and logging for bad initial energy. Co-authored-by: springcoil <[email protected]> Co-authored-by: aseyboldt <[email protected]> * Adding another test one for parallel and non-parallel * Adding random seeds to tests * Removed unnecessary files and formatting in test_sampling * Changed chains to cores in one test * Applying black formatting * Change ParallelSampling to Sampling * Adding back in ParallelSampling * AAAGH remove assertion * Skipping the parallel test for py27 * Forgot to import sys
1 parent 3cb4570 commit 6ccc11b

File tree

8 files changed

+987
-333
lines changed

8 files changed

+987
-333
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Add Incomplete Beta function `incomplete_beta(a, b, value)`
1010
- Add log CDF functions to continuous distributions: `Beta`, `Cauchy`, `ExGaussian`, `Exponential`, `Flat`, `Gumbel`, `HalfCauchy`, `HalfFlat`, `HalfNormal`, `Laplace`, `Logistic`, `Lognormal`, `Normal`, `Pareto`, `StudentT`, `Triangular`, `Uniform`, `Wald`, `Weibull`.
1111
- Behavior of `sample_posterior_predictive` is now to produce posterior predictive samples, in order, from all values of the `trace`. Previously, by default it would produce 1 chain worth of samples, using a random selection from the `trace` (#3212)
12+
- Show diagnostics for initial energy errors in HMC and NUTS.
1213

1314
### Maintenance
1415

pymc3/backends/report.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class WarningType(enum.Enum):
1919
# Indications that chains did not converge, eg Rhat
2020
CONVERGENCE = 6
2121
BAD_ACCEPTANCE = 7
22+
BAD_ENERGY = 8
2223

2324

2425
SamplerWarning = namedtuple(

pymc3/parallel_sampling.py

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,23 @@
55
import logging
66
from collections import namedtuple
77
import traceback
8+
from pymc3.exceptions import SamplingError
89

910
import six
1011
import numpy as np
1112

1213
from . import theanof
1314

14-
logger = logging.getLogger('pymc3')
15+
logger = logging.getLogger("pymc3")
16+
17+
18+
class ParallelSamplingError(Exception):
19+
def __init__(self, message, chain, warnings=None):
20+
super(ParallelSamplingError, self).__init__(message)
21+
if warnings is None:
22+
warnings = []
23+
self._chain = chain
24+
self._warnings = warnings
1525

1626

1727
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -26,7 +36,7 @@ def __str__(self):
2636
class ExceptionWithTraceback:
2737
def __init__(self, exc, tb):
2838
tb = traceback.format_exception(type(exc), exc, tb)
29-
tb = ''.join(tb)
39+
tb = "".join(tb)
3040
self.exc = exc
3141
self.tb = '\n"""\n%s"""' % tb
3242

@@ -40,8 +50,8 @@ def rebuild_exc(exc, tb):
4050

4151

4252
# Messages
43-
# ('writing_done', is_last, sample_idx, tuning, stats)
44-
# ('error', *exception_info)
53+
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
54+
# ('error', warnings, *exception_info)
4555

4656
# ('abort', reason)
4757
# ('write_next',)
@@ -50,12 +60,11 @@ def rebuild_exc(exc, tb):
5060

5161
class _Process(multiprocessing.Process):
5262
"""Seperate process for each chain.
53-
5463
We communicate with the main process using a pipe,
5564
and send finished samples using shared memory.
5665
"""
57-
def __init__(self, name, msg_pipe, step_method, shared_point,
58-
draws, tune, seed):
66+
67+
def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
5968
super(_Process, self).__init__(daemon=True, name=name)
6069
self._msg_pipe = msg_pipe
6170
self._step_method = step_method
@@ -75,7 +84,7 @@ def run(self):
7584
pass
7685
except BaseException as e:
7786
e = ExceptionWithTraceback(e, e.__traceback__)
78-
self._msg_pipe.send(('error', e))
87+
self._msg_pipe.send(("error", None, e))
7988
finally:
8089
self._msg_pipe.close()
8190

@@ -103,14 +112,19 @@ def _start_loop(self):
103112
tuning = True
104113

105114
msg = self._recv_msg()
106-
if msg[0] == 'abort':
115+
if msg[0] == "abort":
107116
raise KeyboardInterrupt()
108-
if msg[0] != 'start':
109-
raise ValueError('Unexpected msg ' + msg[0])
117+
if msg[0] != "start":
118+
raise ValueError("Unexpected msg " + msg[0])
110119

111120
while True:
112121
if draw < self._draws + self._tune:
113-
point, stats = self._compute_point()
122+
try:
123+
point, stats = self._compute_point()
124+
except SamplingError as e:
125+
warns = self._collect_warnings()
126+
e = ExceptionWithTraceback(e, e.__traceback__)
127+
self._msg_pipe.send(("error", warns, e))
114128
else:
115129
return
116130

@@ -119,20 +133,21 @@ def _start_loop(self):
119133
tuning = False
120134

121135
msg = self._recv_msg()
122-
if msg[0] == 'abort':
136+
if msg[0] == "abort":
123137
raise KeyboardInterrupt()
124-
elif msg[0] == 'write_next':
138+
elif msg[0] == "write_next":
125139
self._write_point(point)
126140
is_last = draw + 1 == self._draws + self._tune
127141
if is_last:
128142
warns = self._collect_warnings()
129143
else:
130144
warns = None
131145
self._msg_pipe.send(
132-
('writing_done', is_last, draw, tuning, stats, warns))
146+
("writing_done", is_last, draw, tuning, stats, warns)
147+
)
133148
draw += 1
134149
else:
135-
raise ValueError('Unknown message ' + msg[0])
150+
raise ValueError("Unknown message " + msg[0])
136151

137152
def _compute_point(self):
138153
if self._step_method.generates_stats:
@@ -143,14 +158,15 @@ def _compute_point(self):
143158
return point, stats
144159

145160
def _collect_warnings(self):
146-
if hasattr(self._step_method, 'warnings'):
161+
if hasattr(self._step_method, "warnings"):
147162
return self._step_method.warnings()
148163
else:
149164
return []
150165

151166

152167
class ProcessAdapter(object):
153168
"""Control a Chain process from the main thread."""
169+
154170
def __init__(self, draws, tune, step_method, chain, seed, start):
155171
self.chain = chain
156172
process_name = "worker_chain_%s" % chain
@@ -164,9 +180,9 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
164180
size *= int(dim)
165181
size *= dtype.itemsize
166182
if size != ctypes.c_size_t(size).value:
167-
raise ValueError('Variable %s is too large' % name)
183+
raise ValueError("Variable %s is too large" % name)
168184

169-
array = multiprocessing.sharedctypes.RawArray('c', size)
185+
array = multiprocessing.sharedctypes.RawArray("c", size)
170186
self._shared_point[name] = array
171187
array_np = np.frombuffer(array, dtype).reshape(shape)
172188
array_np[...] = start[name]
@@ -176,8 +192,14 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
176192
self._num_samples = 0
177193

178194
self._process = _Process(
179-
process_name, remote_conn, step_method, self._shared_point,
180-
draws, tune, seed)
195+
process_name,
196+
remote_conn,
197+
step_method,
198+
self._shared_point,
199+
draws,
200+
tune,
201+
seed,
202+
)
181203
# We fork right away, so that the main process can start tqdm threads
182204
self._process.start()
183205

@@ -191,14 +213,14 @@ def shared_point_view(self):
191213
return self._point
192214

193215
def start(self):
194-
self._msg_pipe.send(('start',))
216+
self._msg_pipe.send(("start",))
195217

196218
def write_next(self):
197219
self._readable = False
198-
self._msg_pipe.send(('write_next',))
220+
self._msg_pipe.send(("write_next",))
199221

200222
def abort(self):
201-
self._msg_pipe.send(('abort',))
223+
self._msg_pipe.send(("abort",))
202224

203225
def join(self, timeout=None):
204226
self._process.join(timeout)
@@ -209,24 +231,28 @@ def terminate(self):
209231
@staticmethod
210232
def recv_draw(processes, timeout=3600):
211233
if not processes:
212-
raise ValueError('No processes.')
234+
raise ValueError("No processes.")
213235
pipes = [proc._msg_pipe for proc in processes]
214236
ready = multiprocessing.connection.wait(pipes)
215237
if not ready:
216-
raise multiprocessing.TimeoutError('No message from samplers.')
238+
raise multiprocessing.TimeoutError("No message from samplers.")
217239
idxs = {id(proc._msg_pipe): proc for proc in processes}
218240
proc = idxs[id(ready[0])]
219241
msg = ready[0].recv()
220242

221-
if msg[0] == 'error':
222-
old = msg[1]
223-
six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
224-
elif msg[0] == 'writing_done':
243+
if msg[0] == "error":
244+
warns, old_error = msg[1:]
245+
if warns is not None:
246+
error = ParallelSamplingError(str(old_error), proc.chain, warns)
247+
else:
248+
error = RuntimeError("Chain %s failed." % proc.chain)
249+
six.raise_from(error, old_error)
250+
elif msg[0] == "writing_done":
225251
proc._readable = True
226252
proc._num_samples += 1
227253
return (proc,) + msg[1:]
228254
else:
229-
raise ValueError('Sampler sent bad message.')
255+
raise ValueError("Sampler sent bad message.")
230256

231257
@staticmethod
232258
def terminate_all(processes, patience=2):
@@ -244,34 +270,46 @@ def terminate_all(processes, patience=2):
244270
raise multiprocessing.TimeoutError()
245271
process.join(timeout)
246272
except multiprocessing.TimeoutError:
247-
logger.warn('Chain processes did not terminate as expected. '
248-
'Terminating forcefully...')
273+
logger.warn(
274+
"Chain processes did not terminate as expected. "
275+
"Terminating forcefully..."
276+
)
249277
for process in processes:
250278
process.terminate()
251279
for process in processes:
252280
process.join()
253281

254282

255283
Draw = namedtuple(
256-
'Draw',
257-
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
284+
"Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"]
258285
)
259286

260287

261288
class ParallelSampler(object):
262-
def __init__(self, draws, tune, chains, cores, seeds, start_points,
263-
step_method, start_chain_num=0, progressbar=True):
289+
def __init__(
290+
self,
291+
draws,
292+
tune,
293+
chains,
294+
cores,
295+
seeds,
296+
start_points,
297+
step_method,
298+
start_chain_num=0,
299+
progressbar=True,
300+
):
264301
if progressbar:
265302
import tqdm
303+
266304
tqdm_ = tqdm.tqdm
267305

268306
if any(len(arg) != chains for arg in [seeds, start_points]):
269-
raise ValueError(
270-
'Number of seeds and start_points must be %s.' % chains)
307+
raise ValueError("Number of seeds and start_points must be %s." % chains)
271308

272309
self._samplers = [
273-
ProcessAdapter(draws, tune, step_method,
274-
chain + start_chain_num, seed, start)
310+
ProcessAdapter(
311+
draws, tune, step_method, chain + start_chain_num, seed, start
312+
)
275313
for chain, seed, start in zip(range(chains), seeds, start_points)
276314
]
277315

@@ -286,8 +324,10 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
286324
self._progress = None
287325
if progressbar:
288326
self._progress = tqdm_(
289-
total=chains * (draws + tune), unit='draws',
290-
desc='Sampling %s chains' % chains)
327+
total=chains * (draws + tune),
328+
unit="draws",
329+
desc="Sampling %s chains" % chains,
330+
)
291331

292332
def _make_active(self):
293333
while self._inactive and len(self._active) < self._max_active:
@@ -298,7 +338,7 @@ def _make_active(self):
298338

299339
def __iter__(self):
300340
if not self._in_context:
301-
raise ValueError('Use ParallelSampler as context manager.')
341+
raise ValueError("Use ParallelSampler as context manager.")
302342
self._make_active()
303343

304344
while self._active:
@@ -317,8 +357,7 @@ def __iter__(self):
317357
# and only call proc.write_next() after the yield returns.
318358
# This seems to be faster overally though, as the worker
319359
# loses less time waiting.
320-
point = {name: val.copy()
321-
for name, val in proc.shared_point_view.items()}
360+
point = {name: val.copy() for name, val in proc.shared_point_view.items()}
322361

323362
# Already called for new proc in _make_active
324363
if not is_last:

pymc3/sampling.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -986,17 +986,28 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
986986
draws, tune, chains, cores, random_seed, start, step,
987987
chain, progressbar)
988988
try:
989-
with sampler:
990-
for draw in sampler:
991-
trace = traces[draw.chain - chain]
992-
if trace.supports_sampler_stats and draw.stats is not None:
993-
trace.record(draw.point, draw.stats)
994-
else:
995-
trace.record(draw.point)
996-
if draw.is_last:
997-
trace.close()
998-
if draw.warnings is not None:
999-
trace._add_warnings(draw.warnings)
989+
try:
990+
with sampler:
991+
for draw in sampler:
992+
trace = traces[draw.chain - chain]
993+
if (trace.supports_sampler_stats
994+
and draw.stats is not None):
995+
trace.record(draw.point, draw.stats)
996+
else:
997+
trace.record(draw.point)
998+
if draw.is_last:
999+
trace.close()
1000+
if draw.warnings is not None:
1001+
trace._add_warnings(draw.warnings)
1002+
except ps.ParallelSamplingError as error:
1003+
trace = traces[error._chain - chain]
1004+
trace._add_warnings(error._warnings)
1005+
for trace in traces:
1006+
trace.close()
1007+
1008+
multitrace = MultiTrace(traces)
1009+
multitrace._report._log_summary()
1010+
raise
10001011
return MultiTrace(traces)
10011012
except KeyboardInterrupt:
10021013
traces, length = _choose_chains(traces, tune)
@@ -1512,4 +1523,4 @@ def init_nuts(init='auto', chains=1, n_init=500000, model=None,
15121523

15131524
step = pm.NUTS(potential=potential, model=model, **kwargs)
15141525

1515-
return start, step
1526+
return start, step

0 commit comments

Comments
 (0)