Skip to content

Commit 69a4e60

Browse files
committed
Add SMC Metropolis Kernel
1 parent 2257c34 commit 69a4e60

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

pymc3/smc/smc.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,92 @@ def sample_settings(self):
427427
return stats
428428

429429

430+
class MH(SMC_KERNEL):
431+
"""Metropolis-Hastings SMC kernel"""
432+
433+
def __init__(self, *args, n_steps=25, **kwargs):
434+
"""
435+
Parameters
436+
----------
437+
n_steps: int
438+
The number of steps of each Markov Chain.
439+
"""
440+
super().__init__(*args, **kwargs)
441+
self.n_steps = n_steps
442+
443+
self.proposal_dist = None
444+
self.proposal_scales = None
445+
self.chain_acc_rate = None
446+
447+
def setup_kernel(self):
448+
"""Proposal dist is just a Multivariate Normal with unit identity covariance.
449+
Dimension specific scaling is provided by self.proposal_scales and set in self.tune()
450+
"""
451+
ndim = self.tempered_posterior.shape[1]
452+
self.proposal_dist = multivariate_normal(
453+
mean=np.zeros(ndim),
454+
cov=np.eye(ndim),
455+
)
456+
self.proposal_scales = np.full(self.draws, min(1, 2.38 ** 2 / ndim))
457+
458+
def resample(self):
459+
super().resample()
460+
if self.iteration > 1:
461+
self.proposal_scales = self.proposal_scales[self.resampling_indexes]
462+
self.chain_acc_rate = self.chain_acc_rate[self.resampling_indexes]
463+
464+
def tune(self):
465+
"""Update proposal scales for each particle dimension"""
466+
if self.iteration > 1:
467+
# Rescale based on distance to 0.234 acceptance rate
468+
chain_scales = np.exp(np.log(self.proposal_scales) + (self.chain_acc_rate - 0.234))
469+
# Interpolate between individual and population scales
470+
self.proposal_scales = 0.5 * chain_scales + 0.5 * chain_scales.mean()
471+
472+
def mutate(self):
473+
"""Metropolis-Hastings perturbation."""
474+
ac_ = np.empty((self.n_steps, self.draws))
475+
476+
log_R = np.log(np.random.rand(self.n_steps, self.draws))
477+
for n_step in range(self.n_steps):
478+
proposal = floatX(
479+
self.tempered_posterior
480+
+ self.proposal_dist.rvs(size=self.draws) * self.proposal_scales[:, None]
481+
)
482+
ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
483+
pl = np.array([self.prior_logp_func(prop) for prop in proposal])
484+
485+
proposal_logp = pl + ll * self.beta
486+
accepted = log_R[n_step] < (proposal_logp - self.tempered_posterior_logp)
487+
488+
ac_[n_step] = accepted
489+
self.tempered_posterior[accepted] = proposal[accepted]
490+
self.prior_logp[accepted] = pl[accepted]
491+
self.likelihood_logp[accepted] = ll[accepted]
492+
self.tempered_posterior_logp[accepted] = proposal_logp[accepted]
493+
494+
self.chain_acc_rate = np.mean(ac_, axis=0)
495+
496+
def sample_stats(self):
497+
stats = super().sample_stats()
498+
stats.update(
499+
{
500+
"mean_accept_rate": self.chain_acc_rate.mean(),
501+
"mean_proposal_scale": self.proposal_scales.mean(),
502+
}
503+
)
504+
return stats
505+
506+
def sample_settings(self):
507+
stats = super().sample_settings()
508+
stats.update(
509+
{
510+
"_n_tune": self.n_steps, # Default property name used in `SamplerReport`
511+
}
512+
)
513+
return stats
514+
515+
430516
def _logp_forw(point, out_vars, vars, shared):
431517
"""Compile Aesara function of the model and the input and output variables.
432518

pymc3/tests/test_smc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,17 @@ def normal_sim(a, b):
329329
)
330330
with pytest.raises(NotImplementedError, match="named models"):
331331
pm.sample_smc(draws=10, kernel="ABC")
332+
333+
334+
class TestMHKernel(SeededTest):
335+
def test_normal_model(self):
336+
data = st.norm(10, 0.5).rvs(1000, random_state=self.get_random_state())
337+
with pm.Model() as m:
338+
mu = pm.Normal("mu", 0, 3)
339+
sigma = pm.HalfNormal("sigma", 1)
340+
y = pm.Normal("y", mu, sigma, observed=data)
341+
idata = pm.sample_smc(draws=2000, kernel=pm.smc.MH)
342+
343+
post = idata.posterior.stack(sample=("chain", "draw"))
344+
assert np.abs(post["mu"].mean() - 10) < 0.1
345+
assert np.abs(post["sigma"].mean() - 0.5) < 0.05

0 commit comments

Comments
 (0)