Skip to content

Commit 212ff07

Browse files
committed
Add warnings in new multiprocessing sampling
1 parent 03b10de commit 212ff07

File tree

5 files changed

+31
-17
lines changed

5 files changed

+31
-17
lines changed

pymc3/parallel_sampling.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,12 @@ def _start_loop(self):
9898
elif msg[0] == 'write_next':
9999
self._write_point(point)
100100
is_last = draw + 1 == self._draws + self._tune
101+
if is_last:
102+
warns = self._collect_warnings()
103+
else:
104+
warns = None
101105
self._msg_pipe.send(
102-
('writing_done', is_last, draw, tuning, stats))
106+
('writing_done', is_last, draw, tuning, stats, warns))
103107
draw += 1
104108
else:
105109
raise ValueError('Unknown message ' + msg[0])
@@ -112,6 +116,12 @@ def _compute_point(self):
112116
stats = None
113117
return point, stats
114118

119+
def _collect_warnings(self):
120+
if hasattr(self._step_method, 'warnings'):
121+
return self._step_method.warnings()
122+
else:
123+
return []
124+
115125

116126
class ProcessAdapter(object):
117127
"""Control a Chain process from the main thread."""
@@ -218,7 +228,7 @@ def terminate_all(processes, patience=2):
218228

219229
Draw = namedtuple(
220230
'Draw',
221-
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point']
231+
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
222232
)
223233

224234

@@ -274,7 +284,7 @@ def __iter__(self):
274284

275285
while self._active:
276286
draw = ProcessAdapter.recv_draw(self._active)
277-
proc, is_last, draw, tuning, stats = draw
287+
proc, is_last, draw, tuning, stats, warns = draw
278288
if self._progress is not None:
279289
self._progress[proc.chain - self._start_chain_num].update()
280290

@@ -299,7 +309,7 @@ def __iter__(self):
299309
if not is_last:
300310
proc.write_next()
301311

302-
yield Draw(proc.chain, is_last, draw, tuning, stats, point)
312+
yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)
303313

304314
def __enter__(self):
305315
self._in_context = True

pymc3/sampling.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
663663
except KeyboardInterrupt:
664664
strace.close()
665665
if hasattr(step, 'warnings'):
666-
warns = step.warnings(strace)
666+
warns = step.warnings()
667667
strace._add_warnings(warns)
668668
raise
669669
except BaseException:
@@ -672,7 +672,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
672672
else:
673673
strace.close()
674674
if hasattr(step, 'warnings'):
675-
warns = step.warnings(strace)
675+
warns = step.warnings()
676676
strace._add_warnings(warns)
677677

678678

@@ -1002,6 +1002,8 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
10021002
trace.record(draw.point)
10031003
if draw.is_last:
10041004
trace.close()
1005+
if draw.warnings is not None:
1006+
trace._add_warnings(draw.warnings)
10051007
return MultiTrace(traces)
10061008
except KeyboardInterrupt:
10071009
traces, length = _choose_chains(traces, tune)
@@ -1013,12 +1015,14 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
10131015
else:
10141016
chain_nums = list(range(chain, chain + chains))
10151017
pbars = [progressbar] + [False] * (chains - 1)
1016-
jobs = (delayed(_sample)(
1017-
chain=args[0], progressbar=args[1], random_seed=args[2],
1018-
start=args[3], draws=draws, step=step, trace=trace,
1019-
tune=tune, model=model, **kwargs
1020-
)
1021-
for args in zip(chain_nums, pbars, random_seed, start))
1018+
jobs = (
1019+
delayed(_sample)(
1020+
chain=args[0], progressbar=args[1], random_seed=args[2],
1021+
start=args[3], draws=draws, step=step, trace=trace,
1022+
tune=tune, model=model, **kwargs
1023+
)
1024+
for args in zip(chain_nums, pbars, random_seed, start)
1025+
)
10221026
if use_mmap:
10231027
traces = Parallel(n_jobs=cores)(jobs)
10241028
else:

pymc3/step_methods/compound.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def step(self, point):
3333
point = method.step(point)
3434
return point
3535

36-
def warnings(self, strace):
36+
def warnings(self):
3737
warns = []
3838
for method in self.methods:
3939
if hasattr(method, 'warnings'):
40-
warns.extend(method.warnings(strace))
40+
warns.extend(method.warnings())
4141
return warns
4242

4343
def stop_tuning(self):

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def reset(self, start=None):
164164
self.tune = True
165165
self.potential.reset()
166166

167-
def warnings(self, strace):
167+
def warnings(self):
168168
# list.copy() is not available in python2
169169
warnings = self._warnings[:]
170170

pymc3/step_methods/hmc/nuts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def competence(var, has_grad):
184184
return Competence.IDEAL
185185
return Competence.INCOMPATIBLE
186186

187-
def warnings(self, strace):
188-
warnings = super(NUTS, self).warnings(strace)
187+
def warnings(self):
188+
warnings = super(NUTS, self).warnings()
189189
n_samples = self._samples_after_tune
190190
n_treedepth = self._reached_max_treedepth
191191

0 commit comments

Comments
 (0)