Skip to content

Commit bc3aea6

Browse files
aseyboldtricardoV94
authored andcommitted
Avoid repeated status polling in smc
1 parent 8a68a5c commit bc3aea6

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pymc/smc/sampling.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import warnings
1919

2020
from collections import defaultdict
21-
from concurrent.futures import ProcessPoolExecutor
21+
from concurrent.futures import ProcessPoolExecutor, wait
2222
from typing import Any
2323

2424
import cloudpickle
@@ -404,13 +404,19 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
404404
)
405405

406406
# monitor the progress:
407-
while sum([future.done() for future in futures]) < len(futures):
407+
done = []
408+
remaining = futures
409+
while len(remaining) > 0:
410+
finished, remaining = wait(remaining, timeout=0.1)
411+
done.extend(finished)
408412
for task_id, update_data in _progress.items():
409413
stage = update_data["stage"]
410414
beta = update_data["beta"]
411415
# update the progress bar for this task:
412416
progress.update(
413-
status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True
417+
status=f"Stage: {stage} Beta: {beta:.3f}",
418+
task_id=task_id,
419+
refresh=True,
414420
)
415421

416-
return tuple(cloudpickle.loads(r.result()) for r in futures)
422+
return tuple(cloudpickle.loads(r.result()) for r in done)

0 commit comments

Comments
 (0)