
Refactor of Sequential Monte Carlo internals #5274


Closed
51 changes: 51 additions & 0 deletions pymc/smc/runners.py
@@ -0,0 +1,51 @@
import multiprocessing as mp
Member

I am not sure this needs to be in its own separate file. The original file is pretty concise

Contributor Author

Right, I didn't extract it because of size, but because running chains in parallel vs. sequentially feels independent of SMC and could be reused somewhere else. LMK if that's enough reason to keep it outside, or I'll move it back in.

Member

If a similar need emerges in another sampler we can refactor it then. Doing it now seems like premature refactoring.

Also, it wouldn't make sense to host it inside the SMC module.


from itertools import repeat

import cloudpickle

from fastprogress.fastprogress import progress_bar


def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores):
pbar = progress_bar((), total=100, display=progressbar)
pbar.update(0)
pbars = [pbar] + [None] * (chains - 1)

pool = mp.Pool(cores)

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
results = _starmap_with_kwargs(
pool,
to_run,
[(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
repeat(kernel_kwargs),
)
results = tuple(cloudpickle.loads(r) for r in results)
pool.close()
pool.join()
return results


def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs):
results = []
pbar = progress_bar((), total=100 * chains, display=progressbar)
pbar.update(0)
for chain in range(chains):
pbar.offset = 100 * chain
pbar.base_comment = f"Chain: {chain + 1}/{chains}"
results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs))
return results


def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
# Helper function to allow kwargs with Pool.starmap
# Copied from https://stackoverflow.com/a/53173433/13311693
args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
return pool.starmap(_apply_args_and_kwargs, args_for_starmap)


def _apply_args_and_kwargs(fn, args, kwargs):
return fn(*args, **kwargs)
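A quick note for reviewers on the helper at the bottom of this file: `Pool.starmap` only unpacks positional arguments, so keyword arguments are zipped alongside the function and applied by a small trampoline in the worker process. A self-contained sketch of the same pattern (the `work` function and its arguments are made up for illustration, not part of this PR):

```python
import multiprocessing as mp
from itertools import repeat


def _apply_args_and_kwargs(fn, args, kwargs):
    # Runs in the worker process and applies the kwargs that starmap cannot.
    return fn(*args, **kwargs)


def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    # Bundle (fn, args, kwargs) tuples; starmap unpacks them into the trampoline.
    return pool.starmap(_apply_args_and_kwargs, zip(repeat(fn), args_iter, kwargs_iter))


def work(x, y, *, scale=1):  # hypothetical worker, not part of the PR
    return (x + y) * scale


if __name__ == "__main__":
    with mp.Pool(2) as pool:
        print(_starmap_with_kwargs(pool, work, [(1, 2), (3, 4)], repeat({"scale": 10})))
    # -> [30, 70]
```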
105 changes: 48 additions & 57 deletions pymc/smc/sample_smc.py
@@ -13,26 +13,24 @@
# limitations under the License.

import logging
import multiprocessing as mp
import time
import warnings

from collections import defaultdict
from collections.abc import Iterable
from itertools import repeat

import cloudpickle
import numpy as np

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar

import pymc

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.model import modelcontext
from pymc.parallel_sampling import _cpu_count
from pymc.smc.runners import run_chains_parallel, run_chains_sequential
from pymc.smc.smc import IMH


@@ -222,52 +220,54 @@ def sample_smc(
)

t1 = time.time()

if cores > 1:
pbar = progress_bar((), total=100, display=progressbar)
pbar.update(0)
pbars = [pbar] + [None] * (chains - 1)

pool = mp.Pool(cores)

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
results = _starmap_with_kwargs(
pool,
_sample_smc_int,
[(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
repeat(kernel_kwargs),
results = run_chains_parallel(
chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
)
results = tuple(cloudpickle.loads(r) for r in results)
pool.close()
pool.join()

else:
results = []
pbar = progress_bar((), total=100 * chains, display=progressbar)
pbar.update(0)
for chain in range(chains):
pbar.offset = 100 * chain
pbar.base_comment = f"Chain: {chain+1}/{chains}"
results.append(
_sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs)
)

results = run_chains_sequential(
chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
)
(
traces,
sample_stats,
sample_settings,
) = zip(*results)

trace = MultiTrace(traces)
idata = None

# Save sample_stats
_t_sampling = time.time() - t1
sample_stats, idata = _save_sample_stats(
sample_settings,
sample_stats,
chains,
trace,
return_inferencedata,
_t_sampling,
idata_kwargs,
model,
)

if compute_convergence_checks:
_compute_convergence_checks(idata, draws, model, trace)
return idata if return_inferencedata else trace


def _save_sample_stats(
sample_settings,
sample_stats,
chains,
trace,
return_inferencedata,
_t_sampling,
idata_kwargs,
model,
):
sample_settings_dict = sample_settings[0]
sample_settings_dict["_t_sampling"] = _t_sampling

sample_stats_dict = sample_stats[0]

if chains > 1:
# Collect the stat values from each chain in a single list
for stat in sample_stats[0].keys():
@@ -281,6 +281,7 @@ def sample_smc(
setattr(trace.report, stat, value)
for stat, value in sample_settings_dict.items():
setattr(trace.report, stat, value)
idata = None
else:
for stat, value in sample_stats_dict.items():
if chains > 1:
@@ -303,19 +304,20 @@ def sample_smc(
idata = to_inference_data(trace, **ikwargs)
idata = InferenceData(**idata, sample_stats=sample_stats)

if compute_convergence_checks:
if draws < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
else:
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
trace.report._run_convergence_checks(idata, model)
trace.report._log_summary()
return sample_stats, idata

return idata if return_inferencedata else trace

def _compute_convergence_checks(idata, draws, model, trace):
if draws < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
else:
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
trace.report._run_convergence_checks(idata, model)
trace.report._log_summary()


def _sample_smc_int(
@@ -389,14 +391,3 @@ def _sample_smc_int(
results = cloudpickle.dumps(results)

return results


def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
# Helper function to allow kwargs with Pool.starmap
# Copied from https://stackoverflow.com/a/53173433/13311693
args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
return pool.starmap(_apply_args_and_kwargs, args_for_starmap)


def _apply_args_and_kwargs(fn, args, kwargs):
return fn(*args, **kwargs)
13 changes: 9 additions & 4 deletions pymc/smc/smc.py
@@ -151,6 +151,8 @@ def __init__(

self.draws = draws
self.start = start
if threshold < 0 or threshold > 1:
raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
self.threshold = threshold
self.model = model
self.rng = np.random.default_rng(seed=random_seed)
@@ -192,7 +194,6 @@ def _initialize_kernel(self):
initial_point = self.model.recompute_initial_point(seed=self.rng.integers(2 ** 30))
for v in self.variables:
self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)

# Create particles bijection map
if self.start:
init_rnd = self.start
@@ -203,6 +204,7 @@ def _initialize_kernel(self):
for i in range(self.draws):
point = Point({v.name: init_rnd[v.name][i] for v in self.variables}, model=self.model)
population.append(DictToArrayBijection.map(point).data)

self.tempered_posterior = np.array(floatX(population))

# Initialize prior and likelihood log probabilities
@@ -228,13 +230,16 @@ def setup_kernel(self):
def update_beta_and_weights(self):
"""Calculate the next inverse temperature (beta)

The importance weights based on two sucesive tempered likelihoods (i.e.
The importance weights based on two successive tempered likelihoods (i.e.
two successive values of beta) and updates the marginal likelihood estimate.

ESS is calculated for importance sampling. BDA 3rd ed. eq 10.4
"""
self.iteration += 1

low_beta = old_beta = self.beta
up_beta = 2.0

rN = int(len(self.likelihood_logp) * self.threshold)

while up_beta - low_beta > 1e-6:
@@ -268,6 +273,7 @@ def resample(self):
self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
self.prior_logp = self.prior_logp[self.resampling_indexes]
self.likelihood_logp = self.likelihood_logp[self.resampling_indexes]

self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta

def tune(self):
@@ -303,7 +309,7 @@ def sample_settings(self) -> Dict:
def _posterior_to_trace(self, chain=0) -> NDArray:
"""Save results into a PyMC trace

This method shoud not be overwritten.
This method should not be overwritten.
"""
lenght_pos = len(self.tempered_posterior)
varnames = [v.name for v in self.variables]
@@ -497,7 +503,6 @@ def tune(self):
def mutate(self):
"""Metropolis-Hastings perturbation."""
ac_ = np.empty((self.n_steps, self.draws))

log_R = np.log(self.rng.random((self.n_steps, self.draws)))
for n_step in range(self.n_steps):
proposal = floatX(
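For context on the `update_beta_and_weights` hunk above: the next inverse temperature is found by bisection so that the effective sample size of the incremental importance weights (BDA 3rd ed., eq. 10.4: ESS = 1 / sum of squared normalized weights) lands near `threshold * draws`. A stripped-down sketch of that idea, not the exact PyMC code (names and the return convention are illustrative):

```python
import numpy as np
from scipy.special import logsumexp


def next_beta(likelihood_logp, old_beta, threshold=0.5, tol=1e-6):
    """Bisect for the largest temperature step whose importance weights keep
    ESS near threshold * n_particles, capping the result at beta = 1."""
    target_ess = threshold * len(likelihood_logp)
    low, up = old_beta, 2.0  # searching above 1 lets the cap below kick in
    while up - low > tol:
        new_beta = 0.5 * (low + up)
        log_w = (new_beta - old_beta) * likelihood_logp  # incremental log weights
        log_w -= logsumexp(log_w)                        # normalize
        ess = np.exp(-logsumexp(2 * log_w))              # 1 / sum(w_i ** 2)
        if ess < target_ess:
            up = new_beta    # weights too degenerate, shrink the step
        else:
            low = new_beta   # weights still healthy, push beta further
    new_beta = min(1.0, new_beta)
    log_w = (new_beta - old_beta) * likelihood_logp
    log_w -= logsumexp(log_w)
    return new_beta, np.exp(log_w)
```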
16 changes: 12 additions & 4 deletions pymc/tests/test_smc.py
@@ -42,7 +42,7 @@ def setup_class(self):
super().setup_class()
self.samples = 1000
n = 4
mu1 = np.ones(n) * (1.0 / 2)
mu1 = np.ones(n) * 0.5
mu2 = -mu1

stdev = 0.1
@@ -54,6 +54,9 @@
w2 = 1 - stdev

def two_gaussians(x):
"""
Mixture of gaussians likelihood
"""
log_like1 = (
-0.5 * n * at.log(2 * np.pi)
- 0.5 * at.log(dsigma)
@@ -80,8 +83,9 @@ def test_sample(self):
initial_rng_state = np.random.get_state()
with self.SMC_test:
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
assert_random_state_equal(initial_rng_state, np.random.get_state())

assert_random_state_equal(
initial_rng_state, np.random.get_state()
) # TODO: why this? maybe to verify that nothing was sampled?
Member

The PR where this test was added provides some context: #5131

Since it was not obvious why it was there, a comment before the assert explaining the purpose of the check could be helpful.
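A possible wording for such a comment, just a guess at the intent based on the discussion in #5131 (that SMC should only consume randomness from its own seeded generator):

```python
# Sampling must not touch the global NumPy random state: all randomness in
# sample_smc should come from its own seeded Generator.
assert_random_state_equal(initial_rng_state, np.random.get_state())
```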

x = mtrace["X"]
mu1d = np.abs(x).mean(axis=0)
np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03)
@@ -109,7 +113,6 @@ def test_discrete_rounding_proposal(self):
def test_unobserved_discrete(self):
n = 10
rng = self.get_random_state()

z_true = np.zeros(n, dtype=int)
z_true[int(n / 2) :] = 1
y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng)
@@ -124,6 +127,10 @@ def test_unobserved_discrete(self):
assert np.all(np.median(trace["z"], axis=0) == z_true)

def test_marginal_likelihood(self):
"""
Verifies that the log marginal likelihood function
can be correctly computed for a Beta-Bernoulli model.
"""
data = np.repeat([1, 0], [50, 50])
marginals = []
a_prior_0, b_prior_0 = 1.0, 1.0
@@ -135,6 +142,7 @@ def test_marginal_likelihood(self):
y = pm.Bernoulli("y", a, observed=data)
trace = pm.sample_smc(2000, return_inferencedata=False)
marginals.append(trace.report.log_marginal_likelihood)

# compare to the analytical result
assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1

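As a side note on `test_marginal_likelihood`: for Bernoulli observations with s successes out of n trials and a Beta(a, b) prior on the success probability, the marginal likelihood has the closed form p(D) = B(a + s, b + n - s) / B(a, b), which is what the SMC estimates are checked against. A small sketch of that reference computation (the second prior's parameters here are illustrative, not necessarily the ones used in the elided part of the test):

```python
import numpy as np
from scipy.special import betaln


def beta_bernoulli_log_marginal(a, b, successes, trials):
    # log p(D) = log B(a + s, b + n - s) - log B(a, b)
    return betaln(a + successes, b + trials - successes) - betaln(a, b)


lml_flat = beta_bernoulli_log_marginal(1.0, 1.0, successes=50, trials=100)
lml_other = beta_bernoulli_log_marginal(20.0, 20.0, successes=50, trials=100)  # hypothetical prior
print(np.exp(lml_other - lml_flat))  # analytical Bayes factor the SMC estimates should approach
```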