diff --git a/pymc/smc/runners.py b/pymc/smc/runners.py
new file mode 100644
index 0000000000..2d45fcbf70
--- /dev/null
+++ b/pymc/smc/runners.py
@@ -0,0 +1,51 @@
+import multiprocessing as mp
+
+from itertools import repeat
+
+import cloudpickle
+
+from fastprogress.fastprogress import progress_bar
+
+
+def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores):
+    pbar = progress_bar((), total=100, display=progressbar)
+    pbar.update(0)
+    pbars = [pbar] + [None] * (chains - 1)
+
+    pool = mp.Pool(cores)
+
+    # "manually" (de)serialize params before/after multiprocessing
+    params = tuple(cloudpickle.dumps(p) for p in params)
+    kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
+    results = _starmap_with_kwargs(
+        pool,
+        to_run,
+        [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
+        repeat(kernel_kwargs),
+    )
+    results = tuple(cloudpickle.loads(r) for r in results)
+    pool.close()
+    pool.join()
+    return results
+
+
+def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs):
+    results = []
+    pbar = progress_bar((), total=100 * chains, display=progressbar)
+    pbar.update(0)
+    for chain in range(chains):
+        pbar.offset = 100 * chain
+        pbar.base_comment = f"Chain: {chain + 1}/{chains}"
+        results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs))
+    return results
+
+
+def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
+    # Helper function to allow kwargs with Pool.starmap
+    # Copied from https://stackoverflow.com/a/53173433/13311693
+    args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
+    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
+
+
+def _apply_args_and_kwargs(fn, args, kwargs):
+    return fn(*args, **kwargs)
diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py
index 0bc229e4d7..1a09dfe695 100644
--- a/pymc/smc/sample_smc.py
+++ b/pymc/smc/sample_smc.py
@@ -13,19 +13,16 @@
 #   limitations under the License.
 
 import logging
-import multiprocessing as mp
 import time
 import warnings
 
 from collections import defaultdict
 from collections.abc import Iterable
-from itertools import repeat
 
 import cloudpickle
 import numpy as np
 
 from arviz import InferenceData
-from fastprogress.fastprogress import progress_bar
 
 import pymc
 
@@ -33,6 +30,7 @@
 from pymc.backends.base import MultiTrace
 from pymc.model import modelcontext
 from pymc.parallel_sampling import _cpu_count
+from pymc.smc.runners import run_chains_parallel, run_chains_sequential
 from pymc.smc.smc import IMH
@@ -222,37 +220,15 @@ def sample_smc(
     )
 
     t1 = time.time()
+
     if cores > 1:
-        pbar = progress_bar((), total=100, display=progressbar)
-        pbar.update(0)
-        pbars = [pbar] + [None] * (chains - 1)
-
-        pool = mp.Pool(cores)
-
-        # "manually" (de)serialize params before/after multiprocessing
-        params = tuple(cloudpickle.dumps(p) for p in params)
-        kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
-        results = _starmap_with_kwargs(
-            pool,
-            _sample_smc_int,
-            [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
-            repeat(kernel_kwargs),
+        results = run_chains_parallel(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
         )
-        results = tuple(cloudpickle.loads(r) for r in results)
-        pool.close()
-        pool.join()
-
     else:
-        results = []
-        pbar = progress_bar((), total=100 * chains, display=progressbar)
-        pbar.update(0)
-        for chain in range(chains):
-            pbar.offset = 100 * chain
-            pbar.base_comment = f"Chain: {chain+1}/{chains}"
-            results.append(
-                _sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs)
-            )
-
+        results = run_chains_sequential(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
+        )
     (
         traces,
         sample_stats,
@@ -260,14 +236,38 @@
     ) = zip(*results)
 
     trace = MultiTrace(traces)
-    idata = None
 
-    # Save sample_stats
     _t_sampling = time.time() - t1
+    sample_stats, idata = _save_sample_stats(
+        sample_settings,
+        sample_stats,
+        chains,
+        trace,
+        return_inferencedata,
+        _t_sampling,
+        idata_kwargs,
+        model,
+    )
+
+    if compute_convergence_checks:
+        _compute_convergence_checks(idata, draws, model, trace)
+    return idata if return_inferencedata else trace
+
+
+def _save_sample_stats(
+    sample_settings,
+    sample_stats,
+    chains,
+    trace,
+    return_inferencedata,
+    _t_sampling,
+    idata_kwargs,
+    model,
+):
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
-
     sample_stats_dict = sample_stats[0]
+
     if chains > 1:
         # Collect the stat values from each chain in a single list
         for stat in sample_stats[0].keys():
@@ -281,6 +281,7 @@
             setattr(trace.report, stat, value)
         for stat, value in sample_settings_dict.items():
             setattr(trace.report, stat, value)
+        idata = None
     else:
         for stat, value in sample_stats_dict.items():
             if chains > 1:
@@ -303,19 +304,20 @@
         idata = to_inference_data(trace, **ikwargs)
         idata = InferenceData(**idata, sample_stats=sample_stats)
 
-    if compute_convergence_checks:
-        if draws < 100:
-            warnings.warn(
-                "The number of samples is too small to check convergence reliably.",
-                stacklevel=2,
-            )
-        else:
-            if idata is None:
-                idata = to_inference_data(trace, log_likelihood=False)
-            trace.report._run_convergence_checks(idata, model)
-            trace.report._log_summary()
+    return sample_stats, idata
 
-    return idata if return_inferencedata else trace
+
+def _compute_convergence_checks(idata, draws, model, trace):
+    if draws < 100:
+        warnings.warn(
+            "The number of samples is too small to check convergence reliably.",
+            stacklevel=2,
+        )
+    else:
+        if idata is None:
+            idata = to_inference_data(trace, log_likelihood=False)
+        trace.report._run_convergence_checks(idata, model)
+        trace.report._log_summary()
 
 
 def _sample_smc_int(
@@ -389,14 +391,3 @@
     results = cloudpickle.dumps(results)
 
     return results
-
-
-def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
-    # Helper function to allow kwargs with Pool.starmap
-    # Copied from https://stackoverflow.com/a/53173433/13311693
-    args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
-    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
-
-
-def _apply_args_and_kwargs(fn, args, kwargs):
-    return fn(*args, **kwargs)
diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py
index 88d15ced80..91a3a46e3d 100644
--- a/pymc/smc/smc.py
+++ b/pymc/smc/smc.py
@@ -151,6 +151,8 @@ def __init__(
         self.draws = draws
         self.start = start
+        if threshold < 0 or threshold > 1:
+            raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
         self.threshold = threshold
         self.model = model
         self.rng = np.random.default_rng(seed=random_seed)
@@ -192,7 +194,6 @@ def _initialize_kernel(self):
         initial_point = self.model.recompute_initial_point(seed=self.rng.integers(2 ** 30))
         for v in self.variables:
             self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
-
         # Create particles bijection map
         if self.start:
             init_rnd = self.start
@@ -203,6 +204,7 @@
         for i in range(self.draws):
             point = Point({v.name: init_rnd[v.name][i] for v in self.variables}, model=self.model)
             population.append(DictToArrayBijection.map(point).data)
+
         self.tempered_posterior = np.array(floatX(population))
 
         # Initialize prior and likelihood log probabilities
@@ -228,13 +230,16 @@ def setup_kernel(self):
     def update_beta_and_weights(self):
         """Calculate the next inverse temperature (beta)
 
-        The importance weights based on two sucesive tempered likelihoods (i.e.
+        The importance weights based on two successive tempered likelihoods (i.e.
         two successive values of beta) and updates the marginal likelihood estimate.
+
+        ESS is calculated for importance sampling. BDA 3rd ed. eq 10.4
         """
         self.iteration += 1
 
         low_beta = old_beta = self.beta
         up_beta = 2.0
+
         rN = int(len(self.likelihood_logp) * self.threshold)
 
         while up_beta - low_beta > 1e-6:
@@ -268,6 +273,7 @@ def resample(self):
         self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
         self.prior_logp = self.prior_logp[self.resampling_indexes]
         self.likelihood_logp = self.likelihood_logp[self.resampling_indexes]
+
         self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta
 
     def tune(self):
@@ -303,7 +309,7 @@ def sample_settings(self) -> Dict:
     def _posterior_to_trace(self, chain=0) -> NDArray:
         """Save results into a PyMC trace
 
-        This method shoud not be overwritten.
+        This method should not be overwritten.
""" lenght_pos = len(self.tempered_posterior) varnames = [v.name for v in self.variables] @@ -497,7 +503,6 @@ def tune(self): def mutate(self): """Metropolis-Hastings perturbation.""" ac_ = np.empty((self.n_steps, self.draws)) - log_R = np.log(self.rng.random((self.n_steps, self.draws))) for n_step in range(self.n_steps): proposal = floatX( diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index e862018ba6..ff24399473 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -42,7 +42,7 @@ def setup_class(self): super().setup_class() self.samples = 1000 n = 4 - mu1 = np.ones(n) * (1.0 / 2) + mu1 = np.ones(n) * 0.5 mu2 = -mu1 stdev = 0.1 @@ -54,6 +54,9 @@ def setup_class(self): w2 = 1 - stdev def two_gaussians(x): + """ + Mixture of gaussians likelihood + """ log_like1 = ( -0.5 * n * at.log(2 * np.pi) - 0.5 * at.log(dsigma) @@ -80,8 +83,9 @@ def test_sample(self): initial_rng_state = np.random.get_state() with self.SMC_test: mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False) - assert_random_state_equal(initial_rng_state, np.random.get_state()) - + assert_random_state_equal( + initial_rng_state, np.random.get_state() + ) # TODO: why this? maybe to verify that nothing was sampled? x = mtrace["X"] mu1d = np.abs(x).mean(axis=0) np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03) @@ -109,7 +113,6 @@ def test_discrete_rounding_proposal(self): def test_unobserved_discrete(self): n = 10 rng = self.get_random_state() - z_true = np.zeros(n, dtype=int) z_true[int(n / 2) :] = 1 y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng) @@ -124,6 +127,10 @@ def test_unobserved_discrete(self): assert np.all(np.median(trace["z"], axis=0) == z_true) def test_marginal_likelihood(self): + """ + Verifies that the log marginal likelihood function + can be correctly computed for a Beta-Bernoulli model. + """ data = np.repeat([1, 0], [50, 50]) marginals = [] a_prior_0, b_prior_0 = 1.0, 1.0 @@ -135,6 +142,7 @@ def test_marginal_likelihood(self): y = pm.Bernoulli("y", a, observed=data) trace = pm.sample_smc(2000, return_inferencedata=False) marginals.append(trace.report.log_marginal_likelihood) + # compare to the analytical result assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1