
Commit e6fc2ec

ciguaran and ricardoV94 authored
Refactor of Sequential Monte Carlo internals (#5281)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 29720d0 commit e6fc2ec

File tree: 3 files changed (+100, -50 lines)

pymc/smc/sample_smc.py

Lines changed: 80 additions & 43 deletions

@@ -222,52 +222,54 @@ def sample_smc(
         )

     t1 = time.time()
+
     if cores > 1:
-        pbar = progress_bar((), total=100, display=progressbar)
-        pbar.update(0)
-        pbars = [pbar] + [None] * (chains - 1)
-
-        pool = mp.Pool(cores)
-
-        # "manually" (de)serialize params before/after multiprocessing
-        params = tuple(cloudpickle.dumps(p) for p in params)
-        kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
-        results = _starmap_with_kwargs(
-            pool,
-            _sample_smc_int,
-            [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
-            repeat(kernel_kwargs),
+        results = run_chains_parallel(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
         )
-        results = tuple(cloudpickle.loads(r) for r in results)
-        pool.close()
-        pool.join()
-
     else:
-        results = []
-        pbar = progress_bar((), total=100 * chains, display=progressbar)
-        pbar.update(0)
-        for chain in range(chains):
-            pbar.offset = 100 * chain
-            pbar.base_comment = f"Chain: {chain+1}/{chains}"
-            results.append(
-                _sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs)
-            )
-
+        results = run_chains_sequential(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
+        )
     (
         traces,
         sample_stats,
         sample_settings,
     ) = zip(*results)

     trace = MultiTrace(traces)
-    idata = None

-    # Save sample_stats
     _t_sampling = time.time() - t1
+    sample_stats, idata = _save_sample_stats(
+        sample_settings,
+        sample_stats,
+        chains,
+        trace,
+        return_inferencedata,
+        _t_sampling,
+        idata_kwargs,
+        model,
+    )
+
+    if compute_convergence_checks:
+        _compute_convergence_checks(idata, draws, model, trace)
+    return idata if return_inferencedata else trace
+
+
+def _save_sample_stats(
+    sample_settings,
+    sample_stats,
+    chains,
+    trace,
+    return_inferencedata,
+    _t_sampling,
+    idata_kwargs,
+    model,
+):
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
-
     sample_stats_dict = sample_stats[0]
+
     if chains > 1:
         # Collect the stat values from each chain in a single list
         for stat in sample_stats[0].keys():
@@ -281,6 +283,7 @@ def sample_smc(
             setattr(trace.report, stat, value)
         for stat, value in sample_settings_dict.items():
             setattr(trace.report, stat, value)
+        idata = None
     else:
         for stat, value in sample_stats_dict.items():
             if chains > 1:
@@ -303,19 +306,20 @@ def sample_smc(
         idata = to_inference_data(trace, **ikwargs)
         idata = InferenceData(**idata, sample_stats=sample_stats)

-    if compute_convergence_checks:
-        if draws < 100:
-            warnings.warn(
-                "The number of samples is too small to check convergence reliably.",
-                stacklevel=2,
-            )
-        else:
-            if idata is None:
-                idata = to_inference_data(trace, log_likelihood=False)
-            trace.report._run_convergence_checks(idata, model)
-        trace.report._log_summary()
+    return sample_stats, idata

-    return idata if return_inferencedata else trace
+
+def _compute_convergence_checks(idata, draws, model, trace):
+    if draws < 100:
+        warnings.warn(
+            "The number of samples is too small to check convergence reliably.",
+            stacklevel=2,
+        )
+    else:
+        if idata is None:
+            idata = to_inference_data(trace, log_likelihood=False)
+        trace.report._run_convergence_checks(idata, model)
+    trace.report._log_summary()


 def _sample_smc_int(
@@ -391,6 +395,39 @@ def _sample_smc_int(
     return results


+def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores):
+    pbar = progress_bar((), total=100, display=progressbar)
+    pbar.update(0)
+    pbars = [pbar] + [None] * (chains - 1)
+
+    pool = mp.Pool(cores)
+
+    # "manually" (de)serialize params before/after multiprocessing
+    params = tuple(cloudpickle.dumps(p) for p in params)
+    kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
+    results = _starmap_with_kwargs(
+        pool,
+        to_run,
+        [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
+        repeat(kernel_kwargs),
+    )
+    results = tuple(cloudpickle.loads(r) for r in results)
+    pool.close()
+    pool.join()
+    return results
+
+
+def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs):
+    results = []
+    pbar = progress_bar((), total=100 * chains, display=progressbar)
+    pbar.update(0)
+    for chain in range(chains):
+        pbar.offset = 100 * chain
+        pbar.base_comment = f"Chain: {chain + 1}/{chains}"
+        results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs))
+    return results
+
+
 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
     # Helper function to allow kwargs with Pool.starmap
     # Copied from https://stackoverflow.com/a/53173433/13311693
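Editor's note on the pattern above: run_chains_parallel still funnels everything through _starmap_with_kwargs, which exists because Pool.starmap only unpacks positional arguments. A minimal, self-contained sketch of the idea (the work function is a hypothetical stand-in for _sample_smc_int; this is not code from the commit):

import multiprocessing as mp
from itertools import repeat


def work(chain, seed, verbose=False):
    # Stand-in for the per-chain sampling function.
    return (chain, seed, verbose)


def apply_args_and_kwargs(fn, args, kwargs):
    return fn(*args, **kwargs)


def starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    # Bundle (fn, args, kwargs) triples and unpack them in a trampoline,
    # since Pool.starmap has no native kwargs support.
    return pool.starmap(apply_args_and_kwargs, zip(repeat(fn), args_iter, kwargs_iter))


if __name__ == "__main__":
    with mp.Pool(2) as pool:
        results = starmap_with_kwargs(pool, work, [(0, 1), (1, 2)], repeat({"verbose": True}))
    print(results)  # [(0, 1, True), (1, 2, True)]

The cloudpickle round-trip in run_chains_parallel serves a related purpose: arguments the default pickler rejects are serialized manually before crossing the process boundary and deserialized inside the worker.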

pymc/smc/smc.py

Lines changed: 9 additions & 4 deletions

@@ -151,6 +151,8 @@ def __init__(

         self.draws = draws
         self.start = start
+        if threshold < 0 or threshold > 1:
+            raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
         self.threshold = threshold
         self.model = model
         self.rng = np.random.default_rng(seed=random_seed)
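The new guard rejects a threshold outside [0, 1] at construction time instead of letting it silently distort the resampling target computed from it later (rN, below). A usage sketch, under the assumption that threshold reaches the kernel through sample_smc's kernel kwargs as in the rest of this refactor:

import pymc as pm

with pm.Model():
    pm.Normal("x", 0, 1)
    # Expected to raise: ValueError: Threshold value 1.5 must be between 0 and 1
    pm.sample_smc(draws=100, threshold=1.5)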
@@ -192,7 +194,6 @@ def _initialize_kernel(self):
         initial_point = self.model.recompute_initial_point(seed=self.rng.integers(2 ** 30))
         for v in self.variables:
             self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
-
         # Create particles bijection map
         if self.start:
             init_rnd = self.start
@@ -203,6 +204,7 @@ def _initialize_kernel(self):
         for i in range(self.draws):
             point = Point({v.name: init_rnd[v.name][i] for v in self.variables}, model=self.model)
             population.append(DictToArrayBijection.map(point).data)
+
         self.tempered_posterior = np.array(floatX(population))

         # Initialize prior and likelihood log probabilities
@@ -228,13 +230,16 @@ def setup_kernel(self):
     def update_beta_and_weights(self):
         """Calculate the next inverse temperature (beta)

-        The importance weights based on two sucesive tempered likelihoods (i.e.
+        The importance weights based on two successive tempered likelihoods (i.e.
         two successive values of beta) and updates the marginal likelihood estimate.
+
+        ESS is calculated for importance sampling. BDA 3rd ed. eq 10.4
         """
         self.iteration += 1

         low_beta = old_beta = self.beta
         up_beta = 2.0
+
         rN = int(len(self.likelihood_logp) * self.threshold)

         while up_beta - low_beta > 1e-6:
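The docstring addition names the quantity this loop controls: the effective sample size (ESS) of the incremental importance weights, eq. 10.4 of BDA 3rd ed. A condensed numpy sketch of the bisection being set up above (plain arguments instead of kernel attributes; a paraphrase, not the commit's code):

import numpy as np
from scipy.special import logsumexp


def next_beta(likelihood_logp, old_beta, threshold=0.5):
    # Raise beta until the ESS of the incremental weights falls to
    # threshold * n_particles, then stop tempering at beta = 1.
    rN = int(len(likelihood_logp) * threshold)
    low_beta, up_beta = old_beta, 2.0
    new_beta = old_beta
    while up_beta - low_beta > 1e-6:
        new_beta = (low_beta + up_beta) / 2.0
        log_weights_un = (new_beta - old_beta) * likelihood_logp
        log_weights = log_weights_un - logsumexp(log_weights_un)
        # BDA 3rd ed. eq. 10.4: ESS = 1 / sum(normalized_weights ** 2)
        ESS = int(np.exp(-logsumexp(log_weights * 2)))
        if ESS == rN:
            break
        elif ESS < rN:
            up_beta = new_beta  # weights too degenerate, back off
        else:
            low_beta = new_beta  # weights still healthy, temper harder
    return min(new_beta, 1.0)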
@@ -268,6 +273,7 @@ def resample(self):
         self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
         self.prior_logp = self.prior_logp[self.resampling_indexes]
         self.likelihood_logp = self.likelihood_logp[self.resampling_indexes]
+
         self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta

     def tune(self):
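For context, resampling_indexes is drawn upstream from the normalized importance weights, so the particle arrays above stay aligned after re-indexing. A short numpy sketch of that multinomial resampling step (assuming weights holds the normalized weights):

import numpy as np

rng = np.random.default_rng(42)
draws = 1000
weights = np.full(draws, 1.0 / draws)  # stand-in for the normalized weights
# High-weight particles are duplicated, negligible ones are dropped.
resampling_indexes = rng.choice(draws, size=draws, p=weights)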
@@ -303,7 +309,7 @@ def sample_settings(self) -> Dict:
     def _posterior_to_trace(self, chain=0) -> NDArray:
         """Save results into a PyMC trace

-        This method shoud not be overwritten.
+        This method should not be overwritten.
         """
         lenght_pos = len(self.tempered_posterior)
         varnames = [v.name for v in self.variables]
@@ -497,7 +503,6 @@ def tune(self):
     def mutate(self):
         """Metropolis-Hastings perturbation."""
         ac_ = np.empty((self.n_steps, self.draws))
-
         log_R = np.log(self.rng.random((self.n_steps, self.draws)))
         for n_step in range(self.n_steps):
             proposal = floatX(
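log_R pre-draws all the log-uniforms needed for the accept/reject decisions, one row per mutation step, so each step can apply a vectorized Metropolis-Hastings test across every particle at once. A schematic sketch with hypothetical log-probability arrays (the kernel's actual proposal construction is truncated above):

import numpy as np

rng = np.random.default_rng(0)
n_steps, draws = 5, 1000
log_R = np.log(rng.random((n_steps, draws)))

# Stand-ins for the tempered log posterior of current and proposed particles.
current_logp = rng.normal(size=draws)
proposal_logp = rng.normal(size=draws)

# Vectorized Metropolis-Hastings acceptance for mutation step 0.
accepted = log_R[0] < (proposal_logp - current_logp)
current_logp = np.where(accepted, proposal_logp, current_logp)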

pymc/tests/test_smc.py

Lines changed: 11 additions & 3 deletions

@@ -42,7 +42,7 @@ def setup_class(self):
         super().setup_class()
         self.samples = 1000
         n = 4
-        mu1 = np.ones(n) * (1.0 / 2)
+        mu1 = np.ones(n) * 0.5
         mu2 = -mu1

         stdev = 0.1
@@ -54,6 +54,9 @@ def setup_class(self):
         w2 = 1 - stdev

         def two_gaussians(x):
+            """
+            Mixture of gaussians likelihood
+            """
             log_like1 = (
                 -0.5 * n * at.log(2 * np.pi)
                 - 0.5 * at.log(dsigma)
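The new docstring names what the aesara expression below it builds: a two-component Gaussian mixture log likelihood, log p(x) = logsumexp_i(log w_i + log N(x; mu_i, Sigma)). An equivalent scipy sketch under the same setup (w1 = stdev is inferred from the surrounding test code and is an assumption):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

n, stdev = 4, 0.1
mu1 = np.ones(n) * 0.5
mu2 = -mu1
w1, w2 = stdev, 1 - stdev  # mixture weights


def two_gaussians_np(x):
    log_like1 = multivariate_normal(mu1, stdev**2 * np.eye(n)).logpdf(x)
    log_like2 = multivariate_normal(mu2, stdev**2 * np.eye(n)).logpdf(x)
    return logsumexp([np.log(w1) + log_like1, np.log(w2) + log_like2])


print(two_gaussians_np(np.ones(n) * 0.5))  # near mu1, component 1 dominates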
@@ -80,8 +83,9 @@ def test_sample(self):
         initial_rng_state = np.random.get_state()
         with self.SMC_test:
             mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
-        assert_random_state_equal(initial_rng_state, np.random.get_state())

+        # Verify sampling was done with a non-global random generator
+        assert_random_state_equal(initial_rng_state, np.random.get_state())
         x = mtrace["X"]
         mu1d = np.abs(x).mean(axis=0)
         np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03)
@@ -109,7 +113,6 @@ def test_discrete_rounding_proposal(self):
     def test_unobserved_discrete(self):
         n = 10
         rng = self.get_random_state()
-
         z_true = np.zeros(n, dtype=int)
         z_true[int(n / 2) :] = 1
         y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng)
@@ -124,6 +127,10 @@ def test_unobserved_discrete(self):
         assert np.all(np.median(trace["z"], axis=0) == z_true)

     def test_marginal_likelihood(self):
+        """
+        Verifies that the log marginal likelihood function
+        can be correctly computed for a Beta-Bernoulli model.
+        """
         data = np.repeat([1, 0], [50, 50])
         marginals = []
         a_prior_0, b_prior_0 = 1.0, 1.0
@@ -135,6 +142,7 @@ def test_marginal_likelihood(self):
             y = pm.Bernoulli("y", a, observed=data)
             trace = pm.sample_smc(2000, return_inferencedata=False)
             marginals.append(trace.report.log_marginal_likelihood)
+
         # compare to the analytical result
         assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1

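The analytical result the assertion refers to is the closed-form Beta-Bernoulli marginal likelihood: with a Beta(a, b) prior and k successes in n trials, p(y) = B(a + k, b + n - k) / B(a, b). A sketch of the comparison (the second prior's parameters are not visible in this hunk, so Beta(20, 20) below is purely a hypothetical stand-in):

import numpy as np
from scipy.special import betaln

data = np.repeat([1, 0], [50, 50])  # same data as the test
k, n = data.sum(), data.size


def log_marginal_likelihood(a, b):
    # log p(y) = log B(a + k, b + n - k) - log B(a, b)
    return betaln(a + k, b + n - k) - betaln(a, b)


lml_0 = log_marginal_likelihood(1.0, 1.0)    # Beta(1, 1) prior, as in the test
lml_1 = log_marginal_likelihood(20.0, 20.0)  # hypothetical second prior
print(np.exp(lml_1 - lml_0))  # the Bayes factor the SMC estimates are checked against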
0 commit comments