From 095efad9a89bf2e0b6897f8c7cdc4dc2b19929ac Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 16:47:53 -0300 Subject: [PATCH 01/10] more --- pymc/smc/smc.py | 26 +++++++++++++++++++------- pymc/tests/test_smc.py | 29 ++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py index 88d15ced80..bf7c32f568 100644 --- a/pymc/smc/smc.py +++ b/pymc/smc/smc.py @@ -151,6 +151,8 @@ def __init__( self.draws = draws self.start = start + if threshold < 0 or threshold > 1: + raise ValueError(f"Threshold value {threshold} must be between 0 and 1") self.threshold = threshold self.model = model self.rng = np.random.default_rng(seed=random_seed) @@ -192,7 +194,6 @@ def _initialize_kernel(self): initial_point = self.model.recompute_initial_point(seed=self.rng.integers(2 ** 30)) for v in self.variables: self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size) - # Create particles bijection map if self.start: init_rnd = self.start @@ -203,6 +204,7 @@ def _initialize_kernel(self): for i in range(self.draws): point = Point({v.name: init_rnd[v.name][i] for v in self.variables}, model=self.model) population.append(DictToArrayBijection.map(point).data) + self.tempered_posterior = np.array(floatX(population)) # Initialize prior and likelihood log probabilities @@ -228,30 +230,36 @@ def setup_kernel(self): def update_beta_and_weights(self): """Calculate the next inverse temperature (beta) - The importance weights based on two sucesive tempered likelihoods (i.e. + The importance weights based on two successive tempered likelihoods (i.e. two successive values of beta) and updates the marginal likelihood estimate. """ self.iteration += 1 low_beta = old_beta = self.beta up_beta = 2.0 + rN = int(len(self.likelihood_logp) * self.threshold) while up_beta - low_beta > 1e-6: new_beta = (low_beta + up_beta) / 2.0 - log_weights_un = (new_beta - old_beta) * self.likelihood_logp + log_weights_un = (new_beta - old_beta) * self.likelihood_logp #p(theta|y)^CHARLY beta but why old beta is here? log_weights = log_weights_un - logsumexp(log_weights_un) - ESS = int(np.exp(-logsumexp(log_weights * 2))) + + ESS = int(np.exp(-logsumexp(log_weights * 2))) # importance sampling EFF. + + # BISECTION METHOD FOR ESS? if ESS == rN: break elif ESS < rN: up_beta = new_beta else: low_beta = new_beta + if new_beta >= 1: new_beta = 1 log_weights_un = (new_beta - old_beta) * self.likelihood_logp log_weights = log_weights_un - logsumexp(log_weights_un) + # CHARLY why again? self.beta = new_beta self.weights = np.exp(log_weights) @@ -268,6 +276,7 @@ def resample(self): self.tempered_posterior = self.tempered_posterior[self.resampling_indexes] self.prior_logp = self.prior_logp[self.resampling_indexes] self.likelihood_logp = self.likelihood_logp[self.resampling_indexes] + self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta def tune(self): @@ -301,6 +310,7 @@ def sample_settings(self) -> Dict: } def _posterior_to_trace(self, chain=0) -> NDArray: + # CHARLY WHY IS THIS PRIVATE? used from sample_smc """Save results into a PyMC trace This method shoud not be overwritten. @@ -387,6 +397,7 @@ def mutate(self): # This variable is updated at the end of the loop with the entries from the accepted # transitions, which is equivalent to recomputing it in every iteration of the loop. backward_logp = self.proposal_dist.logpdf(self.tempered_posterior) + for n_step in range(self.n_steps): proposal = floatX(self.proposal_dist.rvs(size=self.draws, random_state=self.rng)) proposal = proposal.reshape(len(proposal), -1) @@ -395,7 +406,7 @@ def mutate(self): ll = np.array([self.likelihood_logp_func(prop) for prop in proposal]) pl = np.array([self.prior_logp_func(prop) for prop in proposal]) - proposal_logp = pl + ll * self.beta + proposal_logp = pl + ll * self.beta # this accepted = log_R[n_step] < ( (proposal_logp + backward_logp) - (self.tempered_posterior_logp + forward_logp) ) @@ -497,18 +508,19 @@ def tune(self): def mutate(self): """Metropolis-Hastings perturbation.""" ac_ = np.empty((self.n_steps, self.draws)) - log_R = np.log(self.rng.random((self.n_steps, self.draws))) + for n_step in range(self.n_steps): proposal = floatX( self.tempered_posterior + self.proposal_dist(num_draws=self.draws, rng=self.rng) * self.proposal_scales[:, None] ) + ll = np.array([self.likelihood_logp_func(prop) for prop in proposal]) pl = np.array([self.prior_logp_func(prop) for prop in proposal]) - proposal_logp = pl + ll * self.beta + accepted = log_R[n_step] < (proposal_logp - self.tempered_posterior_logp) ac_[n_step] = accepted diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index e862018ba6..8ae9fbb1c8 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -31,18 +31,18 @@ from pymc.aesaraf import floatX from pymc.backends.base import MultiTrace -from pymc.smc.smc import IMH +from pymc.smc.smc import IMH, MH from pymc.tests.helpers import SeededTest, assert_random_state_equal class TestSMC(SeededTest): - """Tests for the default SMC kernel""" + """Tests for the defa ult SMC kernel""" def setup_class(self): super().setup_class() self.samples = 1000 n = 4 - mu1 = np.ones(n) * (1.0 / 2) + mu1 = np.ones(n) * 0.5 mu2 = -mu1 stdev = 0.1 @@ -68,7 +68,7 @@ def two_gaussians(x): with pm.Model() as self.SMC_test: X = pm.Uniform("X", lower=-2, upper=2.0, shape=n) - llk = pm.Potential("muh", two_gaussians(X)) + llk = pm.Potential("muh", two_gaussians(X)) #Wasn't it easier to use a mixture model directly? self.muref = mu1 @@ -80,8 +80,7 @@ def test_sample(self): initial_rng_state = np.random.get_state() with self.SMC_test: mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False) - assert_random_state_equal(initial_rng_state, np.random.get_state()) - + assert_random_state_equal(initial_rng_state, np.random.get_state()) # TODO: why this? maybe to verify that nothing was sampled? x = mtrace["X"] mu1d = np.abs(x).mean(axis=0) np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03) @@ -106,10 +105,20 @@ def test_discrete_rounding_proposal(self): assert np.isclose(smc.prior_logp_func(floatX(np.array([0.51]))), np.log(0.7)) assert smc.prior_logp_func(floatX(np.array([1.51]))) == -np.inf + def test_mh_kernel(self): + with pm.Model() as m: + z = pm.Bernoulli("z", p=0.7) + like = pm.Potential("like", z * 1.0) + + smc = MH(model=m) + smc.tune() + def test_unobserved_discrete(self): + """ + + """ n = 10 rng = self.get_random_state() - z_true = np.zeros(n, dtype=int) z_true[int(n / 2) :] = 1 y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng) @@ -124,6 +133,10 @@ def test_unobserved_discrete(self): assert np.all(np.median(trace["z"], axis=0) == z_true) def test_marginal_likelihood(self): + """ + Verifies that the log marginal likelihood function + can be correctly computed for a Beta-Bernoulli model. + """ data = np.repeat([1, 0], [50, 50]) marginals = [] a_prior_0, b_prior_0 = 1.0, 1.0 @@ -135,6 +148,7 @@ def test_marginal_likelihood(self): y = pm.Bernoulli("y", a, observed=data) trace = pm.sample_smc(2000, return_inferencedata=False) marginals.append(trace.report.log_marginal_likelihood) + # compare to the analytical result assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1 @@ -148,6 +162,7 @@ def test_start(self): "b_log__": np.abs(np.random.normal(0, 10, size=500)), } trace = pm.sample_smc(500, chains=1, start=start) + #TODO what's the assertion here? def test_kernel_kwargs(self): with self.fast_model: From dd7ebcce198f8c0b3b64e07b83b453d471550463 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 16:54:06 -0300 Subject: [PATCH 02/10] Refactor chain sampling methods --- pymc/smc/sample_smc.py | 84 ++++++++++++++++-------------------------- pymc/smc/smc.py | 7 ++-- pymc/tests/test_smc.py | 10 ++--- 3 files changed, 40 insertions(+), 61 deletions(-) diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py index 0bc229e4d7..8b8c60e871 100644 --- a/pymc/smc/sample_smc.py +++ b/pymc/smc/sample_smc.py @@ -20,7 +20,7 @@ from collections import defaultdict from collections.abc import Iterable from itertools import repeat - +from pymc.smc.utils import run_chains_parallel, run_chains_sequential import cloudpickle import numpy as np @@ -222,37 +222,11 @@ def sample_smc( ) t1 = time.time() + results = None if cores > 1: - pbar = progress_bar((), total=100, display=progressbar) - pbar.update(0) - pbars = [pbar] + [None] * (chains - 1) - - pool = mp.Pool(cores) - - # "manually" (de)serialize params before/after multiprocessing - params = tuple(cloudpickle.dumps(p) for p in params) - kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - results = _starmap_with_kwargs( - pool, - _sample_smc_int, - [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)], - repeat(kernel_kwargs), - ) - results = tuple(cloudpickle.loads(r) for r in results) - pool.close() - pool.join() - + results = run_chains_parallel(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores) else: - results = [] - pbar = progress_bar((), total=100 * chains, display=progressbar) - pbar.update(0) - for chain in range(chains): - pbar.offset = 100 * chain - pbar.base_comment = f"Chain: {chain+1}/{chains}" - results.append( - _sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs) - ) - + results = run_chains_sequential(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs) ( traces, sample_stats, @@ -260,14 +234,22 @@ def sample_smc( ) = zip(*results) trace = MultiTrace(traces) - idata = None - # Save sample_stats _t_sampling = time.time() - t1 + sample_stats, idata = _save_sample_stats(sample_settings, sample_stats, chains, + trace, return_inferencedata, + _t_sampling, idata_kwargs, model) + + if compute_convergence_checks: + _compute_convergence_checks(idata, draws, model, trace) + return idata if return_inferencedata else trace + + +def _save_sample_stats(sample_settings, sample_stats, chains, trace, return_inferencedata, _t_sampling, idata_kwargs, model): sample_settings_dict = sample_settings[0] sample_settings_dict["_t_sampling"] = _t_sampling - sample_stats_dict = sample_stats[0] + if chains > 1: # Collect the stat values from each chain in a single list for stat in sample_stats[0].keys(): @@ -281,6 +263,7 @@ def sample_smc( setattr(trace.report, stat, value) for stat, value in sample_settings_dict.items(): setattr(trace.report, stat, value) + idata = None else: for stat, value in sample_stats_dict.items(): if chains > 1: @@ -303,19 +286,20 @@ def sample_smc( idata = to_inference_data(trace, **ikwargs) idata = InferenceData(**idata, sample_stats=sample_stats) - if compute_convergence_checks: - if draws < 100: - warnings.warn( - "The number of samples is too small to check convergence reliably.", - stacklevel=2, - ) - else: - if idata is None: - idata = to_inference_data(trace, log_likelihood=False) - trace.report._run_convergence_checks(idata, model) - trace.report._log_summary() + return sample_stats, idata - return idata if return_inferencedata else trace + +def _compute_convergence_checks(idata, draws, model, trace): + if draws < 100: + warnings.warn( + "The number of samples is too small to check convergence reliably.", + stacklevel=2, + ) + else: + if idata is None: + idata = to_inference_data(trace, log_likelihood=False) + trace.report._run_convergence_checks(idata, model) + trace.report._log_summary() def _sample_smc_int( @@ -356,7 +340,7 @@ def _sample_smc_int( progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0" progressbar.update_bar(getattr(progressbar, "offset", 0) + 0) - smc._initialize_kernel() + smc._initialize_kernel() # TODO THIS CALLS A PRIVATE METHOD smc.setup_kernel() stage = 0 @@ -391,12 +375,6 @@ def _sample_smc_int( return results -def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): - # Helper function to allow kwargs with Pool.starmap - # Copied from https://stackoverflow.com/a/53173433/13311693 - args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) - return pool.starmap(_apply_args_and_kwargs, args_for_starmap) -def _apply_args_and_kwargs(fn, args, kwargs): - return fn(*args, **kwargs) + diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py index bf7c32f568..44548d9892 100644 --- a/pymc/smc/smc.py +++ b/pymc/smc/smc.py @@ -232,6 +232,8 @@ def update_beta_and_weights(self): The importance weights based on two successive tempered likelihoods (i.e. two successive values of beta) and updates the marginal likelihood estimate. + + ESS is calculated for importance sampling. BDA 3rd ed. eq 10.4 """ self.iteration += 1 @@ -245,9 +247,8 @@ def update_beta_and_weights(self): log_weights_un = (new_beta - old_beta) * self.likelihood_logp #p(theta|y)^CHARLY beta but why old beta is here? log_weights = log_weights_un - logsumexp(log_weights_un) - ESS = int(np.exp(-logsumexp(log_weights * 2))) # importance sampling EFF. + ESS = int(np.exp(-logsumexp(log_weights * 2))) - # BISECTION METHOD FOR ESS? if ESS == rN: break elif ESS < rN: @@ -406,7 +407,7 @@ def mutate(self): ll = np.array([self.likelihood_logp_func(prop) for prop in proposal]) pl = np.array([self.prior_logp_func(prop) for prop in proposal]) - proposal_logp = pl + ll * self.beta # this + proposal_logp = pl + ll * self.beta accepted = log_R[n_step] < ( (proposal_logp + backward_logp) - (self.tempered_posterior_logp + forward_logp) ) diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index 8ae9fbb1c8..5943fd85bb 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -36,7 +36,7 @@ class TestSMC(SeededTest): - """Tests for the defa ult SMC kernel""" + """Tests for the default SMC kernel""" def setup_class(self): super().setup_class() @@ -54,6 +54,9 @@ def setup_class(self): w2 = 1 - stdev def two_gaussians(x): + """ + Mixture of gaussians likelihood + """ log_like1 = ( -0.5 * n * at.log(2 * np.pi) - 0.5 * at.log(dsigma) @@ -68,7 +71,7 @@ def two_gaussians(x): with pm.Model() as self.SMC_test: X = pm.Uniform("X", lower=-2, upper=2.0, shape=n) - llk = pm.Potential("muh", two_gaussians(X)) #Wasn't it easier to use a mixture model directly? + llk = pm.Potential("muh", two_gaussians(X)) self.muref = mu1 @@ -114,9 +117,6 @@ def test_mh_kernel(self): smc.tune() def test_unobserved_discrete(self): - """ - - """ n = 10 rng = self.get_random_state() z_true = np.zeros(n, dtype=int) From 61bb6d885260b01d0308d9177ed6a0fcb54f6091 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 16:56:21 -0300 Subject: [PATCH 03/10] Create utils.py --- pymc/smc/utils.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 pymc/smc/utils.py diff --git a/pymc/smc/utils.py b/pymc/smc/utils.py new file mode 100644 index 0000000000..f333189b8d --- /dev/null +++ b/pymc/smc/utils.py @@ -0,0 +1,70 @@ +import logging +import multiprocessing as mp +import time +import warnings + +from collections import defaultdict +from collections.abc import Iterable +from itertools import repeat + +import cloudpickle +import numpy as np + +from arviz import InferenceData +from fastprogress.fastprogress import progress_bar + +import pymc + +from pymc.backends.arviz import dict_to_dataset, to_inference_data +from pymc.backends.base import MultiTrace +from pymc.model import modelcontext +from pymc.parallel_sampling import _cpu_count +from pymc.smc.smc import IMH + + +def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): + pbar = progress_bar((), total=100, display=progressbar) + pbar.update(0) + pbars = [pbar] + [None] * (chains - 1) + + pool = mp.Pool(cores) + + # "manually" (de)serialize params before/after multiprocessing + params = tuple(cloudpickle.dumps(p) for p in params) + kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} + results = _starmap_with_kwargs( + pool, + to_run, + [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)], + repeat(kernel_kwargs), + ) + results = tuple(cloudpickle.loads(r) for r in results) + pool.close() + pool.join() + return results + + +def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): + # Helper function to allow kwargs with Pool.starmap + # Copied from https://stackoverflow.com/a/53173433/13311693 + args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) + return pool.starmap(_apply_args_and_kwargs, args_for_starmap) + + +def _apply_args_and_kwargs(fn, args, kwargs): + return fn(*args, **kwargs) + + +def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): + #TODO maybe reorder parameters + results = [] + pbar = progress_bar((), total=100 * chains, display=progressbar) + pbar.update(0) + for chain in range(chains): + pbar.offset = 100 * chain + pbar.base_comment = f"Chain: {chain + 1}/{chains}" + results.append( + to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs) + ) + return results + From 5848044b516dd0f850875776019d97a6d2d0bb35 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:04:45 -0300 Subject: [PATCH 04/10] runners file --- pymc/smc/{utils.py => runners.py} | 19 ------------------- pymc/smc/sample_smc.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) rename pymc/smc/{utils.py => runners.py} (79%) diff --git a/pymc/smc/utils.py b/pymc/smc/runners.py similarity index 79% rename from pymc/smc/utils.py rename to pymc/smc/runners.py index f333189b8d..a73296c0d6 100644 --- a/pymc/smc/utils.py +++ b/pymc/smc/runners.py @@ -1,26 +1,8 @@ -import logging import multiprocessing as mp -import time -import warnings - -from collections import defaultdict -from collections.abc import Iterable from itertools import repeat - import cloudpickle -import numpy as np - -from arviz import InferenceData from fastprogress.fastprogress import progress_bar -import pymc - -from pymc.backends.arviz import dict_to_dataset, to_inference_data -from pymc.backends.base import MultiTrace -from pymc.model import modelcontext -from pymc.parallel_sampling import _cpu_count -from pymc.smc.smc import IMH - def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): pbar = progress_bar((), total=100, display=progressbar) @@ -56,7 +38,6 @@ def _apply_args_and_kwargs(fn, args, kwargs): def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): - #TODO maybe reorder parameters results = [] pbar = progress_bar((), total=100 * chains, display=progressbar) pbar.update(0) diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py index 8b8c60e871..d6e06966f1 100644 --- a/pymc/smc/sample_smc.py +++ b/pymc/smc/sample_smc.py @@ -222,7 +222,7 @@ def sample_smc( ) t1 = time.time() - results = None + if cores > 1: results = run_chains_parallel(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores) else: From 6c134be59a46aaf0d9efd14031dadfaf53f95470 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:06:00 -0300 Subject: [PATCH 05/10] Update sample_smc.py --- pymc/smc/sample_smc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py index d6e06966f1..305dae7886 100644 --- a/pymc/smc/sample_smc.py +++ b/pymc/smc/sample_smc.py @@ -13,19 +13,16 @@ # limitations under the License. import logging -import multiprocessing as mp import time import warnings from collections import defaultdict from collections.abc import Iterable -from itertools import repeat -from pymc.smc.utils import run_chains_parallel, run_chains_sequential +from pymc.smc.runners import run_chains_parallel, run_chains_sequential import cloudpickle import numpy as np from arviz import InferenceData -from fastprogress.fastprogress import progress_bar import pymc @@ -222,7 +219,7 @@ def sample_smc( ) t1 = time.time() - + if cores > 1: results = run_chains_parallel(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores) else: From 7b771a033ac39211e61ba5bef90313a3941731d1 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:11:31 -0300 Subject: [PATCH 06/10] a --- pymc/smc/smc.py | 5 +---- pymc/tests/test_smc.py | 10 +--------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py index 44548d9892..b588155545 100644 --- a/pymc/smc/smc.py +++ b/pymc/smc/smc.py @@ -246,21 +246,18 @@ def update_beta_and_weights(self): new_beta = (low_beta + up_beta) / 2.0 log_weights_un = (new_beta - old_beta) * self.likelihood_logp #p(theta|y)^CHARLY beta but why old beta is here? log_weights = log_weights_un - logsumexp(log_weights_un) - ESS = int(np.exp(-logsumexp(log_weights * 2))) - if ESS == rN: break elif ESS < rN: up_beta = new_beta else: low_beta = new_beta - if new_beta >= 1: new_beta = 1 log_weights_un = (new_beta - old_beta) * self.likelihood_logp log_weights = log_weights_un - logsumexp(log_weights_un) - # CHARLY why again? + self.beta = new_beta self.weights = np.exp(log_weights) diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index 5943fd85bb..de082ba576 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -108,14 +108,6 @@ def test_discrete_rounding_proposal(self): assert np.isclose(smc.prior_logp_func(floatX(np.array([0.51]))), np.log(0.7)) assert smc.prior_logp_func(floatX(np.array([1.51]))) == -np.inf - def test_mh_kernel(self): - with pm.Model() as m: - z = pm.Bernoulli("z", p=0.7) - like = pm.Potential("like", z * 1.0) - - smc = MH(model=m) - smc.tune() - def test_unobserved_discrete(self): n = 10 rng = self.get_random_state() @@ -162,7 +154,7 @@ def test_start(self): "b_log__": np.abs(np.random.normal(0, 10, size=500)), } trace = pm.sample_smc(500, chains=1, start=start) - #TODO what's the assertion here? + def test_kernel_kwargs(self): with self.fast_model: From 8bc0fa1619ee4d1555bec229f93659abd812e3c4 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:20:48 -0300 Subject: [PATCH 07/10] f --- pymc/smc/sample_smc.py | 2 +- pymc/smc/smc.py | 11 +++-------- pymc/tests/test_smc.py | 1 - 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py index 305dae7886..38b16730e0 100644 --- a/pymc/smc/sample_smc.py +++ b/pymc/smc/sample_smc.py @@ -337,7 +337,7 @@ def _sample_smc_int( progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0" progressbar.update_bar(getattr(progressbar, "offset", 0) + 0) - smc._initialize_kernel() # TODO THIS CALLS A PRIVATE METHOD + smc._initialize_kernel() smc.setup_kernel() stage = 0 diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py index b588155545..91a3a46e3d 100644 --- a/pymc/smc/smc.py +++ b/pymc/smc/smc.py @@ -244,7 +244,7 @@ def update_beta_and_weights(self): while up_beta - low_beta > 1e-6: new_beta = (low_beta + up_beta) / 2.0 - log_weights_un = (new_beta - old_beta) * self.likelihood_logp #p(theta|y)^CHARLY beta but why old beta is here? + log_weights_un = (new_beta - old_beta) * self.likelihood_logp log_weights = log_weights_un - logsumexp(log_weights_un) ESS = int(np.exp(-logsumexp(log_weights * 2))) if ESS == rN: @@ -258,7 +258,6 @@ def update_beta_and_weights(self): log_weights_un = (new_beta - old_beta) * self.likelihood_logp log_weights = log_weights_un - logsumexp(log_weights_un) - self.beta = new_beta self.weights = np.exp(log_weights) # We normalize again to correct for small numerical errors that might build up @@ -308,10 +307,9 @@ def sample_settings(self) -> Dict: } def _posterior_to_trace(self, chain=0) -> NDArray: - # CHARLY WHY IS THIS PRIVATE? used from sample_smc """Save results into a PyMC trace - This method shoud not be overwritten. + This method should not be overwritten. """ lenght_pos = len(self.tempered_posterior) varnames = [v.name for v in self.variables] @@ -395,7 +393,6 @@ def mutate(self): # This variable is updated at the end of the loop with the entries from the accepted # transitions, which is equivalent to recomputing it in every iteration of the loop. backward_logp = self.proposal_dist.logpdf(self.tempered_posterior) - for n_step in range(self.n_steps): proposal = floatX(self.proposal_dist.rvs(size=self.draws, random_state=self.rng)) proposal = proposal.reshape(len(proposal), -1) @@ -507,18 +504,16 @@ def mutate(self): """Metropolis-Hastings perturbation.""" ac_ = np.empty((self.n_steps, self.draws)) log_R = np.log(self.rng.random((self.n_steps, self.draws))) - for n_step in range(self.n_steps): proposal = floatX( self.tempered_posterior + self.proposal_dist(num_draws=self.draws, rng=self.rng) * self.proposal_scales[:, None] ) - ll = np.array([self.likelihood_logp_func(prop) for prop in proposal]) pl = np.array([self.prior_logp_func(prop) for prop in proposal]) - proposal_logp = pl + ll * self.beta + proposal_logp = pl + ll * self.beta accepted = log_R[n_step] < (proposal_logp - self.tempered_posterior_logp) ac_[n_step] = accepted diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index de082ba576..da797ae80e 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -155,7 +155,6 @@ def test_start(self): } trace = pm.sample_smc(500, chains=1, start=start) - def test_kernel_kwargs(self): with self.fast_model: trace = pm.sample_smc( From cc5927817b7a55f22243cd54cf92ed6f09164c3a Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:21:50 -0300 Subject: [PATCH 08/10] Update test_smc.py --- pymc/tests/test_smc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index da797ae80e..22c23f4c78 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -31,7 +31,7 @@ from pymc.aesaraf import floatX from pymc.backends.base import MultiTrace -from pymc.smc.smc import IMH, MH +from pymc.smc.smc import IMH from pymc.tests.helpers import SeededTest, assert_random_state_equal From 19b4495b6d43b41271e0a3a8556b0ab011c1406b Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 20 Dec 2021 17:33:22 -0300 Subject: [PATCH 09/10] pre commit running --- pymc/smc/runners.py | 8 ++++---- pymc/smc/sample_smc.py | 40 ++++++++++++++++++++++++++++------------ pymc/tests/test_smc.py | 4 +++- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/pymc/smc/runners.py b/pymc/smc/runners.py index a73296c0d6..367d36481f 100644 --- a/pymc/smc/runners.py +++ b/pymc/smc/runners.py @@ -1,6 +1,9 @@ import multiprocessing as mp + from itertools import repeat + import cloudpickle + from fastprogress.fastprogress import progress_bar @@ -44,8 +47,5 @@ def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kern for chain in range(chains): pbar.offset = 100 * chain pbar.base_comment = f"Chain: {chain + 1}/{chains}" - results.append( - to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs) - ) + results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) return results - diff --git a/pymc/smc/sample_smc.py b/pymc/smc/sample_smc.py index 38b16730e0..1a09dfe695 100644 --- a/pymc/smc/sample_smc.py +++ b/pymc/smc/sample_smc.py @@ -18,7 +18,7 @@ from collections import defaultdict from collections.abc import Iterable -from pymc.smc.runners import run_chains_parallel, run_chains_sequential + import cloudpickle import numpy as np @@ -30,6 +30,7 @@ from pymc.backends.base import MultiTrace from pymc.model import modelcontext from pymc.parallel_sampling import _cpu_count +from pymc.smc.runners import run_chains_parallel, run_chains_sequential from pymc.smc.smc import IMH @@ -221,9 +222,13 @@ def sample_smc( t1 = time.time() if cores > 1: - results = run_chains_parallel(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores) + results = run_chains_parallel( + chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores + ) else: - results = run_chains_sequential(chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs) + results = run_chains_sequential( + chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs + ) ( traces, sample_stats, @@ -233,16 +238,32 @@ def sample_smc( trace = MultiTrace(traces) _t_sampling = time.time() - t1 - sample_stats, idata = _save_sample_stats(sample_settings, sample_stats, chains, - trace, return_inferencedata, - _t_sampling, idata_kwargs, model) + sample_stats, idata = _save_sample_stats( + sample_settings, + sample_stats, + chains, + trace, + return_inferencedata, + _t_sampling, + idata_kwargs, + model, + ) if compute_convergence_checks: _compute_convergence_checks(idata, draws, model, trace) return idata if return_inferencedata else trace -def _save_sample_stats(sample_settings, sample_stats, chains, trace, return_inferencedata, _t_sampling, idata_kwargs, model): +def _save_sample_stats( + sample_settings, + sample_stats, + chains, + trace, + return_inferencedata, + _t_sampling, + idata_kwargs, + model, +): sample_settings_dict = sample_settings[0] sample_settings_dict["_t_sampling"] = _t_sampling sample_stats_dict = sample_stats[0] @@ -370,8 +391,3 @@ def _sample_smc_int( results = cloudpickle.dumps(results) return results - - - - - diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index 22c23f4c78..ff24399473 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -83,7 +83,9 @@ def test_sample(self): initial_rng_state = np.random.get_state() with self.SMC_test: mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False) - assert_random_state_equal(initial_rng_state, np.random.get_state()) # TODO: why this? maybe to verify that nothing was sampled? + assert_random_state_equal( + initial_rng_state, np.random.get_state() + ) # TODO: why this? maybe to verify that nothing was sampled? x = mtrace["X"] mu1d = np.abs(x).mean(axis=0) np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03) From e822eea7de34c6d822619f2b24461323bb2c4e47 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Tue, 21 Dec 2021 11:50:29 -0300 Subject: [PATCH 10/10] reorder --- pymc/smc/runners.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pymc/smc/runners.py b/pymc/smc/runners.py index 367d36481f..2d45fcbf70 100644 --- a/pymc/smc/runners.py +++ b/pymc/smc/runners.py @@ -29,17 +29,6 @@ def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel return results -def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): - # Helper function to allow kwargs with Pool.starmap - # Copied from https://stackoverflow.com/a/53173433/13311693 - args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) - return pool.starmap(_apply_args_and_kwargs, args_for_starmap) - - -def _apply_args_and_kwargs(fn, args, kwargs): - return fn(*args, **kwargs) - - def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): results = [] pbar = progress_bar((), total=100 * chains, display=progressbar) @@ -49,3 +38,14 @@ def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kern pbar.base_comment = f"Chain: {chain + 1}/{chains}" results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) return results + + +def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): + # Helper function to allow kwargs with Pool.starmap + # Copied from https://stackoverflow.com/a/53173433/13311693 + args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) + return pool.starmap(_apply_args_and_kwargs, args_for_starmap) + + +def _apply_args_and_kwargs(fn, args, kwargs): + return fn(*args, **kwargs)