Skip to content

Commit 83ecf8d

Browse files
committed
Update progressbar managers with existing fit results from ZarrTrace
1 parent 5b69bb8 commit 83ecf8d

File tree

4 files changed

+35
-7
lines changed

4 files changed

+35
-7
lines changed

pymc/sampling/mcmc.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1157,13 +1157,25 @@ def _sample_many(
11571157

11581158
with progress_manager:
11591159
for i in range(chains):
1160+
trace = traces[i]
1161+
progress_manager.set_initial_state(
1162+
*trace.completed_draws_and_divergences(chain_specific=True)
1163+
)
1164+
progress_manager._progress.update(
1165+
progress_manager.tasks[i],
1166+
draws=progress_manager.completed_draws
1167+
if progress_manager.combined_progress
1168+
else progress_manager.draws,
1169+
divergences=progress_manager.divergences,
1170+
refresh=True,
1171+
)
11601172
step.sampling_state = initial_step_state
11611173
_sample(
11621174
draws=draws,
11631175
chain=i,
11641176
start=start[i],
11651177
step=step,
1166-
trace=traces[i],
1178+
trace=trace,
11671179
rng=rngs[i],
11681180
callback=callback,
11691181
progress_manager=progress_manager,

pymc/sampling/parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,9 @@ def __init__(
512512
progressbar=progressbar,
513513
progressbar_theme=progressbar_theme,
514514
)
515+
if traces is not None:
516+
for trace in traces:
517+
self._progress.set_initial_state(*trace.completed_draws_and_divergences())
515518

516519
def _make_active(self):
517520
while self._inactive and len(self._active) < self._max_active:

pymc/sampling/population.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
2828

2929
from pymc.backends.base import BaseTrace
30-
from pymc.backends.zarr import ZarrChain
3130
from pymc.initial_point import PointType
3231
from pymc.model import Model, modelcontext
3332
from pymc.stats.convergence import log_warning_stats
@@ -110,6 +109,10 @@ def _sample_population(
110109

111110
with CustomProgress(disable=not progressbar) as progress:
112111
task = progress.add_task("[red]Sampling...", total=draws)
112+
for trace in traces:
113+
progress.update(
114+
task, completed=trace.completed_draws_and_divergences(chain_specific=True)[0]
115+
)
113116
for _ in sampling:
114117
progress.update(task)
115118

@@ -197,6 +200,7 @@ def __init__(
197200
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
198201
# ):
199202
task = self._progress.add_task(description=f"Chain {c}")
203+
self._progress.update(task, completed=first_draw_idx)
200204
secondary_end, primary_end = multiprocessing.Pipe()
201205
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
202206
process = multiprocessing.Process(

pymc/util.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ def __init__(
812812

813813
self._show_progress = show_progress
814814
self.divergences = 0
815+
self.draws = 0
815816
self.completed_draws = 0
816817
self.total_draws = draws + tune
817818
self.desc = "Sampling chain"
@@ -827,27 +828,35 @@ def __enter__(self):
827828
def __exit__(self, exc_type, exc_val, exc_tb):
828829
return self._progress.__exit__(exc_type, exc_val, exc_tb)
829830

831+
def set_initial_state(self, draws: int = 0, divergences: int = 0):
832+
self.draws = draws
833+
self.completed_draws += draws
834+
self.divergences += divergences
835+
830836
def _initialize_tasks(self):
831837
if self.combined_progress:
832838
self.tasks = [
833839
self._progress.add_task(
834840
self.desc.format(self),
835-
completed=0,
836-
draws=0,
841+
completed=self.completed_draws,
842+
draws=self.completed_draws,
837843
total=self.total_draws * self.chains - 1,
838844
chain_idx=0,
839845
sampling_speed=0,
840846
speed_unit="draws/s",
841-
**{stat: value[0] for stat, value in self.progress_stats.items()},
847+
**{
848+
stat: value[0] if stat != "diverging" else self.divergences
849+
for stat, value in self.progress_stats.items()
850+
},
842851
)
843852
]
844853

845854
else:
846855
self.tasks = [
847856
self._progress.add_task(
848857
self.desc.format(self),
849-
completed=0,
850-
draws=0,
858+
completed=self.completed_draws,
859+
draws=self.draws,
851860
total=self.total_draws - 1,
852861
chain_idx=chain_idx,
853862
sampling_speed=0,

0 commit comments

Comments
 (0)