Skip to content

Commit ffc3709

Browse files
authored
SMC: Use correlation to autotune n_steps (#5529)
1 parent 71480e0 commit ffc3709

File tree

2 files changed

+68
-83
lines changed

2 files changed

+68
-83
lines changed

pymc/smc/smc.py

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -337,41 +337,21 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
337337
class IMH(SMC_KERNEL):
338338
"""Independent Metropolis-Hastings SMC kernel"""
339339

340-
def __init__(self, *args, n_steps=25, tune_steps=True, p_acc_rate=0.85, **kwargs):
340+
def __init__(self, *args, correlation_threshold=0.01, **kwargs):
341341
"""
342342
Parameters
343343
----------
344-
n_steps: int
345-
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
346-
for the first stage and for the others it will be determined automatically based on the
347-
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
348-
tune_steps: bool
349-
Whether to compute the number of steps automatically or not. Defaults to True
350-
p_acc_rate: float
351-
Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
352-
``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
353-
It should be between 0 and 1.
344+
correlation_threshold: float
345+
The lower the value the higher the number of IMH steps computed automatically.
346+
Defaults to 0.01. It should be between 0 and 1.
354347
"""
355348
super().__init__(*args, **kwargs)
356-
self.n_steps = n_steps
357-
self.tune_steps = tune_steps
358-
self.p_acc_rate = p_acc_rate
349+
self.correlation_threshold = correlation_threshold
359350

360-
self.max_steps = n_steps
361-
self.proposed = self.draws * self.n_steps
362351
self.proposal_dist = None
363352
self.acc_rate = None
364353

365354
def tune(self):
366-
# Tune n_steps based on the acceptance rate (skip in first iteration)
367-
if self.tune_steps and self.iteration > 1:
368-
acc_rate = max(1.0 / self.proposed, self.acc_rate)
369-
self.n_steps = min(
370-
self.max_steps,
371-
max(2, int(np.log(1 - self.p_acc_rate) / np.log(1 - acc_rate))),
372-
)
373-
self.proposed = self.draws * self.n_steps
374-
375355
# Update MVNormal proposal based on the mean and covariance of the
376356
# tempered posterior.
377357
cov = np.cov(self.tempered_posterior, ddof=0, rowvar=0)
@@ -384,34 +364,42 @@ def tune(self):
384364

385365
def mutate(self):
386366
"""Independent Metropolis-Hastings perturbation."""
387-
ac_ = np.empty((self.n_steps, self.draws))
388-
log_R = np.log(self.rng.random((self.n_steps, self.draws)))
389-
390-
# The proposal is independent from the current point.
391-
# We have to take that into account to compute the Metropolis-Hastings acceptance
392-
# We first compute the logp of proposing a transition to the current points.
393-
# This variable is updated at the end of the loop with the entries from the accepted
394-
# transitions, which is equivalent to recomputing it in every iteration of the loop.
395-
backward_logp = self.proposal_dist.logpdf(self.tempered_posterior)
396-
for n_step in range(self.n_steps):
367+
self.n_steps = 1
368+
old_corr = 2
369+
corr = Pearson(self.tempered_posterior)
370+
ac_ = []
371+
while True:
372+
log_R = np.log(self.rng.random(self.draws))
373+
# The proposal is independent from the current point.
374+
# We have to take that into account to compute the Metropolis-Hastings acceptance
375+
# This requires the logp of proposing a transition back to the current points.
376+
# That value is recomputed in every iteration of the loop, because the current
377+
# points change whenever a proposal is accepted.
397378
proposal = floatX(self.proposal_dist.rvs(size=self.draws, random_state=self.rng))
398379
proposal = proposal.reshape(len(proposal), -1)
399-
# We then compute the logp of proposing a transition to the new points
380+
# To do that we compute the logp of moving to a new point
400381
forward_logp = self.proposal_dist.logpdf(proposal)
401-
382+
# And the logp of going back from that new point
383+
backward_logp = self.proposal_dist.logpdf(self.tempered_posterior)
402384
ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
403385
pl = np.array([self.prior_logp_func(prop) for prop in proposal])
404386
proposal_logp = pl + ll * self.beta
405-
accepted = log_R[n_step] < (
387+
accepted = log_R < (
406388
(proposal_logp + backward_logp) - (self.tempered_posterior_logp + forward_logp)
407389
)
408390

409-
ac_[n_step] = accepted
410391
self.tempered_posterior[accepted] = proposal[accepted]
411392
self.tempered_posterior_logp[accepted] = proposal_logp[accepted]
412393
self.prior_logp[accepted] = pl[accepted]
413394
self.likelihood_logp[accepted] = ll[accepted]
414-
backward_logp[accepted] = forward_logp[accepted]
395+
ac_.append(accepted)
396+
self.n_steps += 1
397+
398+
pearson_r = corr.get(self.tempered_posterior)
399+
if np.mean((old_corr - pearson_r) > self.correlation_threshold) > 0.9:
400+
old_corr = pearson_r
401+
else:
402+
break
415403

416404
self.acc_rate = np.mean(ac_)
417405

@@ -429,36 +417,39 @@ def sample_settings(self):
429417
stats.update(
430418
{
431419
"_n_tune": self.n_steps, # Default property name used in `SamplerReport`
432-
"tune_steps": self.tune_steps,
433-
"p_acc_rate": self.p_acc_rate,
420+
"correlation_threshold": self.correlation_threshold,
434421
}
435422
)
436423
return stats
437424

438425

426+
class Pearson:
    """Column-wise absolute Pearson correlation against a fixed reference array.

    The reference array is centered once at construction time, so repeated
    calls to ``get`` only have to center and normalize the new array.
    """

    def __init__(self, a):
        """
        Parameters
        ----------
        a: ndarray
            Reference array. All statistics are computed along axis 0.
        """
        # Number of entries along axis 0 of the reference array.
        self.l = a.shape[0]
        # Reference array centered along axis 0.
        self.am = a - np.sum(a, axis=0) / self.l
        # Euclidean norm (along axis 0) of the centered reference array.
        self.aa = np.sum(self.am**2, axis=0) ** 0.5

    def get(self, b):
        """Absolute Pearson correlation between ``b`` and the reference array.

        Parameters
        ----------
        b: ndarray
            Array with exactly the same shape as the reference array.

        Returns
        -------
        ndarray
            ``|r|`` computed along axis 0. An entry is NaN when the
            corresponding column of either array has zero variance (0 / 0).

        Raises
        ------
        ValueError
            If the shape of ``b`` differs from the shape of the reference array.
        """
        # Without this check a mismatched ``b`` (e.g. a single row) would be
        # centered with the wrong length and silently broadcast against the
        # reference, producing a meaningless correlation.
        if np.shape(b) != self.am.shape:
            raise ValueError(
                f"Expected an array of shape {self.am.shape}, got {np.shape(b)}"
            )

        bm = b - np.sum(b, axis=0) / self.l
        bb = np.sum(bm**2, axis=0) ** 0.5
        ab = np.sum(self.am * bm, axis=0)
        return np.abs(ab / (self.aa * bb))
437+
438+
439439
class MH(SMC_KERNEL):
440440
"""Metropolis-Hastings SMC kernel"""
441441

442-
def __init__(self, *args, n_steps=25, tune_steps=True, p_acc_rate=0.85, **kwargs):
442+
def __init__(self, *args, correlation_threshold=0.01, **kwargs):
443443
"""
444444
Parameters
445445
----------
446-
n_steps: int
447-
The number of steps of each Markov Chain.
448-
tune_steps: bool
449-
Whether to compute the number of steps automatically or not. Defaults to True
450-
p_acc_rate: float
451-
Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
452-
``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
453-
It should be between 0 and 1.
446+
correlation_threshold: float
447+
The lower the value the higher the number of MH steps computed automatically.
448+
Defaults to 0.01. It should be between 0 and 1.
454449
"""
455450
super().__init__(*args, **kwargs)
456-
self.n_steps = n_steps
457-
self.tune_steps = tune_steps
458-
self.p_acc_rate = p_acc_rate
451+
self.correlation_threshold = correlation_threshold
459452

460-
self.max_steps = n_steps
461-
self.proposed = self.draws * self.n_steps
462453
self.proposal_dist = None
463454
self.proposal_scales = None
464455
self.chain_acc_rate = None
@@ -484,14 +475,6 @@ def tune(self):
484475
# Interpolate between individual and population scales
485476
self.proposal_scales = 0.5 * (chain_scales + chain_scales.mean())
486477

487-
if self.tune_steps:
488-
acc_rate = max(1.0 / self.proposed, self.chain_acc_rate.mean())
489-
self.n_steps = min(
490-
self.max_steps,
491-
max(2, int(np.log(1 - self.p_acc_rate) / np.log(1 - acc_rate))),
492-
)
493-
self.proposed = self.draws * self.n_steps
494-
495478
# Update MVNormal proposal based on the covariance of the tempered posterior.
496479
cov = np.cov(self.tempered_posterior, ddof=0, rowvar=0)
497480
cov = np.atleast_2d(cov)
@@ -502,9 +485,12 @@ def tune(self):
502485

503486
def mutate(self):
504487
"""Metropolis-Hastings perturbation."""
505-
ac_ = np.empty((self.n_steps, self.draws))
506-
log_R = np.log(self.rng.random((self.n_steps, self.draws)))
507-
for n_step in range(self.n_steps):
488+
self.n_steps = 1
489+
old_corr = 2
490+
corr = Pearson(self.tempered_posterior)
491+
ac_ = []
492+
while True:
493+
log_R = np.log(self.rng.random(self.draws))
508494
proposal = floatX(
509495
self.tempered_posterior
510496
+ self.proposal_dist(num_draws=self.draws, rng=self.rng)
@@ -514,13 +500,20 @@ def mutate(self):
514500
pl = np.array([self.prior_logp_func(prop) for prop in proposal])
515501

516502
proposal_logp = pl + ll * self.beta
517-
accepted = log_R[n_step] < (proposal_logp - self.tempered_posterior_logp)
503+
accepted = log_R < (proposal_logp - self.tempered_posterior_logp)
518504

519-
ac_[n_step] = accepted
520505
self.tempered_posterior[accepted] = proposal[accepted]
521506
self.prior_logp[accepted] = pl[accepted]
522507
self.likelihood_logp[accepted] = ll[accepted]
523508
self.tempered_posterior_logp[accepted] = proposal_logp[accepted]
509+
ac_.append(accepted)
510+
self.n_steps += 1
511+
512+
pearson_r = corr.get(self.tempered_posterior)
513+
if np.mean((old_corr - pearson_r) > self.correlation_threshold) > 0.9:
514+
old_corr = pearson_r
515+
else:
516+
break
524517

525518
self.chain_acc_rate = np.mean(ac_, axis=0)
526519

@@ -539,8 +532,7 @@ def sample_settings(self):
539532
stats.update(
540533
{
541534
"_n_tune": self.n_steps, # Default property name used in `SamplerReport`
542-
"tune_steps": self.tune_steps,
543-
"p_acc_rate": self.p_acc_rate,
535+
"correlation_threshold": self.correlation_threshold,
544536
}
545537
)
546538
return stats

pymc/tests/test_smc.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,36 +165,29 @@ def test_kernel_kwargs(self):
165165
draws=10,
166166
chains=1,
167167
threshold=0.7,
168-
n_steps=15,
169-
tune_steps=False,
170-
p_acc_rate=0.5,
168+
correlation_threshold=0.02,
171169
return_inferencedata=False,
172170
kernel=pm.smc.IMH,
173171
)
174172

175173
assert trace.report.threshold == 0.7
176174
assert trace.report.n_draws == 10
177-
assert trace.report.n_tune == 15
178-
assert trace.report.tune_steps is False
179-
assert trace.report.p_acc_rate == 0.5
175+
176+
assert trace.report.correlation_threshold == 0.02
180177

181178
with self.fast_model:
182179
trace = pm.sample_smc(
183180
draws=10,
184181
chains=1,
185182
threshold=0.95,
186-
n_steps=15,
187-
tune_steps=False,
188-
p_acc_rate=0.5,
183+
correlation_threshold=0.02,
189184
return_inferencedata=False,
190185
kernel=pm.smc.MH,
191186
)
192187

193188
assert trace.report.threshold == 0.95
194189
assert trace.report.n_draws == 10
195-
assert trace.report.n_tune == 15
196-
assert trace.report.tune_steps is False
197-
assert trace.report.p_acc_rate == 0.5
190+
assert trace.report.correlation_threshold == 0.02
198191

199192
@pytest.mark.parametrize("chains", (1, 2))
200193
def test_return_datatype(self, chains):

0 commit comments

Comments
 (0)