|
24 | 24 | from .exceptions import IncorrectArgumentsError
|
25 | 25 | from pymc3.step_methods.hmc import quadpotential
|
26 | 26 | import pymc3 as pm
|
27 |
| -from tqdm import tqdm |
| 27 | +from tqdm import tqdm |
28 | 28 |
|
29 | 29 |
|
30 | 30 | import sys
|
@@ -539,13 +539,19 @@ def _sample(chain, progressbar, random_seed, start, draws=None, step=None,
|
539 | 539 |
|
540 | 540 | sampling = _iter_sample(draws, step, start, trace, chain,
|
541 | 541 | tune, model, random_seed)
|
| 542 | + _pbar_data = None |
542 | 543 | if progressbar:
|
543 |
| - sampling = tqdm(sampling, total=draws) |
| 544 | + _pbar_data = {"chain": chain, "divergences": 0} |
| 545 | + _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" |
| 546 | + sampling = tqdm(sampling, total=draws, desc=_desc.format(**_pbar_data)) |
544 | 547 | try:
|
545 | 548 | strace = None
|
546 |
| - for it, strace in enumerate(sampling): |
| 549 | + for it, (strace, diverging) in enumerate(sampling): |
547 | 550 | if it >= skip_first:
|
548 | 551 | trace = MultiTrace([strace])
|
| 552 | + if diverging and _pbar_data is not None: |
| 553 | + _pbar_data["divergences"] += 1 |
| 554 | + sampling.set_description(_desc.format(**_pbar_data)) |
549 | 555 | except KeyboardInterrupt:
|
550 | 556 | pass
|
551 | 557 | finally:
|
@@ -591,7 +597,7 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
|
591 | 597 | """
|
592 | 598 | sampling = _iter_sample(draws, step, start, trace, chain, tune,
|
593 | 599 | model, random_seed)
|
594 |
| - for i, strace in enumerate(sampling): |
| 600 | + for i, (strace, _) in enumerate(sampling): |
595 | 601 | yield MultiTrace([strace[:i + 1]])
|
596 | 602 |
|
597 | 603 |
|
@@ -632,15 +638,17 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
|
632 | 638 | if i == tune:
|
633 | 639 | step = stop_tuning(step)
|
634 | 640 | if step.generates_stats:
|
635 |
| - point, states = step.step(point) |
| 641 | + point, stats = step.step(point) |
636 | 642 | if strace.supports_sampler_stats:
|
637 |
| - strace.record(point, states) |
| 643 | + strace.record(point, stats) |
| 644 | + diverging = i > tune and stats and stats[0].get('diverging') |
638 | 645 | else:
|
639 | 646 | strace.record(point)
|
640 | 647 | else:
|
641 | 648 | point = step.step(point)
|
642 | 649 | strace.record(point)
|
643 |
| - yield strace |
| 650 | + diverging = False |
| 651 | + yield strace, diverging |
644 | 652 | except KeyboardInterrupt:
|
645 | 653 | strace.close()
|
646 | 654 | if hasattr(step, 'warnings'):
|
@@ -892,9 +900,9 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
|
892 | 900 | # apply the update to the points and record to the traces
|
893 | 901 | for c, strace in enumerate(traces):
|
894 | 902 | if steppers[c].generates_stats:
|
895 |
| - points[c], states = updates[c] |
| 903 | + points[c], stats = updates[c] |
896 | 904 | if strace.supports_sampler_stats:
|
897 |
| - strace.record(points[c], states) |
| 905 | + strace.record(points[c], stats) |
898 | 906 | else:
|
899 | 907 | strace.record(points[c])
|
900 | 908 | else:
|
|
0 commit comments