Skip to content

Commit 983c8de

Browse files
ColCarrolltwiecki
authored andcommitted
Add live divergence statistics (#3547)
* Add live divergence statistics * Overly enthusiastic on the refactor
1 parent 9e0f260 commit 983c8de

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

pymc3/parallel_sampling.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,14 @@ def __init__(
368368
self._start_chain_num = start_chain_num
369369

370370
self._progress = None
371+
self._divergences = 0
372+
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
373+
self._chains = chains
371374
if progressbar:
372375
self._progress = tqdm(
373376
total=chains * (draws + tune),
374377
unit="draws",
375-
desc="Sampling %s chains" % chains,
378+
desc=self._desc.format(self)
376379
)
377380

378381
def _make_active(self):
@@ -391,6 +394,9 @@ def __iter__(self):
391394
draw = ProcessAdapter.recv_draw(self._active)
392395
proc, is_last, draw, tuning, stats, warns = draw
393396
if self._progress is not None:
397+
if not tuning and stats and stats[0].get('diverging'):
398+
self._divergences += 1
399+
self._progress.set_description(self._desc.format(self))
394400
self._progress.update()
395401

396402
if is_last:

pymc3/sampling.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .exceptions import IncorrectArgumentsError
2525
from pymc3.step_methods.hmc import quadpotential
2626
import pymc3 as pm
27-
from tqdm import tqdm
27+
from tqdm import tqdm
2828

2929

3030
import sys
@@ -539,13 +539,19 @@ def _sample(chain, progressbar, random_seed, start, draws=None, step=None,
539539

540540
sampling = _iter_sample(draws, step, start, trace, chain,
541541
tune, model, random_seed)
542+
_pbar_data = None
542543
if progressbar:
543-
sampling = tqdm(sampling, total=draws)
544+
_pbar_data = {"chain": chain, "divergences": 0}
545+
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
546+
sampling = tqdm(sampling, total=draws, desc=_desc.format(**_pbar_data))
544547
try:
545548
strace = None
546-
for it, strace in enumerate(sampling):
549+
for it, (strace, diverging) in enumerate(sampling):
547550
if it >= skip_first:
548551
trace = MultiTrace([strace])
552+
if diverging and _pbar_data is not None:
553+
_pbar_data["divergences"] += 1
554+
sampling.set_description(_desc.format(**_pbar_data))
549555
except KeyboardInterrupt:
550556
pass
551557
finally:
@@ -591,7 +597,7 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
591597
"""
592598
sampling = _iter_sample(draws, step, start, trace, chain, tune,
593599
model, random_seed)
594-
for i, strace in enumerate(sampling):
600+
for i, (strace, _) in enumerate(sampling):
595601
yield MultiTrace([strace[:i + 1]])
596602

597603

@@ -632,15 +638,17 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
632638
if i == tune:
633639
step = stop_tuning(step)
634640
if step.generates_stats:
635-
point, states = step.step(point)
641+
point, stats = step.step(point)
636642
if strace.supports_sampler_stats:
637-
strace.record(point, states)
643+
strace.record(point, stats)
644+
diverging = i > tune and stats and stats[0].get('diverging')
638645
else:
639646
strace.record(point)
640647
else:
641648
point = step.step(point)
642649
strace.record(point)
643-
yield strace
650+
diverging = False
651+
yield strace, diverging
644652
except KeyboardInterrupt:
645653
strace.close()
646654
if hasattr(step, 'warnings'):
@@ -892,9 +900,9 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
892900
# apply the update to the points and record to the traces
893901
for c, strace in enumerate(traces):
894902
if steppers[c].generates_stats:
895-
points[c], states = updates[c]
903+
points[c], stats = updates[c]
896904
if strace.supports_sampler_stats:
897-
strace.record(points[c], states)
905+
strace.record(points[c], stats)
898906
else:
899907
strace.record(points[c])
900908
else:

0 commit comments

Comments
 (0)