From 46226a83ae401d3e954976dd915bc460efcfa691 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 5 May 2025 10:50:47 +0200 Subject: [PATCH] Alternative fix attempt of progressbar with nested compound step samplers --- pymc/step_methods/compound.py | 23 ++++++++---------- pymc/step_methods/hmc/nuts.py | 16 ++++--------- pymc/step_methods/metropolis.py | 20 +++++++--------- pymc/step_methods/slicer.py | 15 ++++-------- pymc/util.py | 41 +++++++++++++++++++++++++-------- 5 files changed, 56 insertions(+), 59 deletions(-) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..637c04318e 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_update_stats_functions(): + def update_stats(step_stats, chain_idx): + return step_stats - return update_stats + return (update_stats,) # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): @@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): - update_fns = [method._make_update_stats_function() for method in self.methods] - - def update_stats(stats, step_stats, chain_idx): - for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) - - return stats - - return update_stats + def _make_update_stats_functions(self): + update_functions = [] + for method in self.methods: + update_functions.extend(method._make_update_stats_functions()) + return update_functions def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c3592..334a4eac36 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -248,19 +248,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_functions(): + def update_stats(stats): + return {key: stats[key] for key in ("diverging", "step_size", "tree_size")} - if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] - - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats - - return update_stats + return (update_stats,) # A proposal for the next position diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..4d798e9470 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] - - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] - - return stats - - return update_stats + def _make_update_stats_functions(): + def update_stats(step_stats): + return { + "accept_rate" if key == "accept" else key: step_stats[key] + for key in ("tune", "accept", "scaling") + } + + return (update_stats,) def tune(scale, acc_rate): diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..ef5bbebc4c 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - - return stats - - return update_stats + return (update_stats,) diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..bfe50f3507 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -806,9 +806,8 @@ def __init__( progressbar=progressbar, progressbar_theme=progressbar_theme, ) - self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stats_function() + self.update_stats_functions = step_method._make_update_stats_functions() self._show_progress = show_progress self.divergences = 0 @@ -883,12 +882,34 @@ def update(self, chain_idx, is_last, draw, tuning, stats): if not tuning and stats and stats[0].get("diverging"): self.divergences += 1 - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) - more_updates = ( - {stat: value[chain_idx] for stat, value in self.progress_stats.items()} - if self.full_stats - else {} - ) + if self.full_stats: + # TODO: Index by chain already? + chain_progress_stats = [ + update_states_fn(step_stats) + for update_states_fn, step_stats in zip( + self.update_stats_functions, stats, strict=True + ) + ] + all_step_stats = {} + for step_stats in chain_progress_stats: + for key, val in step_stats.items(): + if key in all_step_stats: + continue + count = ( + sum(step_key.startswith(f"{key}_") for step_key in all_step_stats) + 1 + ) + all_step_stats[f"{key}_{count}"] = val + else: + all_step_stats[key] = val + + else: + all_step_stats = {} + + # more_updates = ( + # {stat: value[chain_idx] for stat, value in progress_stats.items()} + # if self.full_stats + # else {} + # ) self._progress.update( self.tasks[chain_idx], @@ -896,14 +917,14 @@ def update(self, chain_idx, is_last, draw, tuning, stats): draws=draw, sampling_speed=speed, speed_unit=unit, - **more_updates, + **all_step_stats, ) if is_last: self._progress.update( self.tasks[chain_idx], draws=draw + 1 if not self.combined_progress else draw, - **more_updates, + **all_step_stats, refresh=True, )