Skip to content

Commit 60a6314

Browse files
authored
Add time remaining column to progress bars (#7273)
* Add time remaining column to progress bars * Consistent order remaining/elapsed * Disable sample_posterior_predictive taskbar when progressbar=False * Formatting * More formatting * More formatting (why doesnt pre-commit fix this?) * Disable progress bar when progress=False * Set refresh flag in progress bar updates * Typo
1 parent 6761c0c commit 60a6314

File tree

8 files changed

+32
-11
lines changed

8 files changed

+32
-11
lines changed

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def apply_function_over_dataset(
659659
out_dict = _DefaultTrace(n_pts)
660660
indices = range(n_pts)
661661

662-
with Progress(console=Console(theme=progressbar_theme)) as progress:
662+
with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
663663
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
664664
for idx in indices:
665665
out = fn(posterior_pts[idx])

pymc/sampling/forward.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,9 @@ def sample_posterior_predictive(
829829
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
830830
ppc_trace_t = _DefaultTrace(samples)
831831
try:
832-
with Progress(console=Console(theme=progressbar_theme)) as progress:
832+
with Progress(
833+
console=Console(theme=progressbar_theme), disable=not progressbar
834+
) as progress:
833835
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
834836
for idx in np.arange(samples):
835837
if nchain > 1:

pymc/sampling/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,8 +1041,8 @@ def _sample(
10411041
for it, diverging in enumerate(sampling_gen):
10421042
if it >= skip_first and diverging:
10431043
_pbar_data["divergences"] += 1
1044-
progress.update(task, advance=1)
1045-
progress.update(task, advance=1, completed=True)
1044+
progress.update(task, refresh=True, advance=1)
1045+
progress.update(task, refresh=True, advance=1, completed=True)
10461046
except KeyboardInterrupt:
10471047
pass
10481048

pymc/sampling/parallel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import numpy as np
2828

2929
from rich.console import Console
30-
from rich.progress import BarColumn, Progress, TimeRemainingColumn
30+
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
3131
from rich.theme import Theme
3232

3333
from pymc.blocking import DictToArrayBijection
@@ -428,7 +428,10 @@ def __init__(
428428
BarColumn(),
429429
"[progress.percentage]{task.percentage:>3.0f}%",
430430
TimeRemainingColumn(),
431+
TextColumn("/"),
432+
TimeElapsedColumn(),
431433
console=Console(theme=progressbar_theme),
434+
disable=not progressbar,
432435
)
433436
self._show_progress = progressbar
434437
self._divergences = 0
@@ -465,6 +468,7 @@ def __iter__(self):
465468
self._divergences += 1
466469
progress.update(
467470
task,
471+
refresh=True,
468472
completed=self._completed_draws,
469473
total=self._total_draws,
470474
description=self._desc.format(self),

pymc/sampling/population.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import cloudpickle
2525
import numpy as np
2626

27-
from rich.progress import BarColumn, Progress, TimeRemainingColumn
27+
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
2828

2929
from pymc.backends.base import BaseTrace
3030
from pymc.initial_point import PointType
@@ -104,7 +104,7 @@ def _sample_population(
104104
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)
105105

106106
for _ in sampling:
107-
progress.update(task, advance=1)
107+
progress.update(task, advance=1, refresh=True)
108108

109109
return
110110

@@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
180180
BarColumn(),
181181
"[progress.percentage]{task.percentage:>3.0f}%",
182182
TimeRemainingColumn(),
183+
TextColumn("/"),
184+
TimeElapsedColumn(),
183185
) as self._progress:
184186
for c, stepper in enumerate(steppers):
185187
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)

pymc/smc/sampling.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
import numpy as np
2626

2727
from arviz import InferenceData
28-
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
28+
from rich.progress import (
29+
Progress,
30+
SpinnerColumn,
31+
TextColumn,
32+
TimeElapsedColumn,
33+
TimeRemainingColumn,
34+
)
2935

3036
import pymc
3137

@@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
366372
with Progress(
367373
TextColumn("{task.description}"),
368374
SpinnerColumn(),
375+
TimeRemainingColumn(),
376+
TextColumn("/"),
369377
TimeElapsedColumn(),
370378
TextColumn("{task.fields[status]}"),
371379
) as progress:
@@ -403,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
403411
stage = update_data["stage"]
404412
beta = update_data["beta"]
405413
# update the progress bar for this task:
406-
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)
414+
progress.update(
415+
status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True
416+
)
407417

408418
return tuple(cloudpickle.loads(r.result()) for r in futures)

pymc/tuning/starting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def find_MAP(
178178
if isinstance(e, StopIteration):
179179
pm._log.info(e)
180180
finally:
181-
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval)
181+
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)
182182
print(file=sys.stdout)
183183

184184
mx0 = RaveledVars(mx0, x0.point_map_info)
@@ -223,6 +223,7 @@ def __init__(
223223
*Progress.get_default_columns(),
224224
TextColumn("{task.fields[loss]}"),
225225
console=Console(theme=progressbar_theme),
226+
disable=not progressbar,
226227
)
227228
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")
228229

pymc/variational/inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def fit(
166166
def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks):
167167
i = 0
168168
try:
169-
with Progress(console=Console(theme=progressbar_theme)) as progress:
169+
with Progress(
170+
console=Console(theme=progressbar_theme), disable=not progressbar
171+
) as progress:
170172
task = progress.add_task("Fitting", total=n, visible=progressbar)
171173
for i in range(n):
172174
step_func()

0 commit comments

Comments
 (0)