Skip to content

Commit b2be720

Browse files
ColCarrolltwiecki
authored andcommitted
metrop_accept returns whether the sample was accepted
1 parent 6c9e848 commit b2be720

File tree

4 files changed

+18
-19
lines changed

4 files changed

+18
-19
lines changed

pymc3/step_methods/arraystep.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def metrop_select(mr, q, q0):
159159
"""Perform rejection/acceptance step for Metropolis class samplers.
160160
161161
Returns the new sample q if a uniform random number is less than the
162-
metropolis acceptance rate (`mr`), and the old sample otherwise.
162+
metropolis acceptance rate (`mr`), and the old sample otherwise, along
163+
with a boolean indicating whether the sample was accepted.
163164
164165
Parameters
165166
----------
@@ -173,6 +174,6 @@ def metrop_select(mr, q, q0):
173174
"""
174175
# Compare acceptance ratio to uniform random number
175176
if np.isfinite(mr) and np.log(uniform()) < mr:
176-
return q
177+
return q, True
177178
else:
178-
return q0
179+
return q0, False

pymc3/step_methods/hmc/hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def astep(self, q0):
6464
initial_energy = self.compute_energy(q, p)
6565
q, p, current_energy = self.leapfrog(q, p, e, n_steps)
6666
energy_change = initial_energy - current_energy
67-
return metrop_select(energy_change, q, q0)
67+
return metrop_select(energy_change, q, q0)[0]
6868

6969
@staticmethod
7070
def competence(var):

pymc3/step_methods/metropolis.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,8 @@ def astep(self, q0):
150150
q = floatX(q0 + delta)
151151

152152
accept = self.delta_logp(q, q0)
153-
q_new = metrop_select(accept, q, q0)
154-
155-
if q_new is q:
156-
self.accepted += 1
153+
q_new, accepted = metrop_select(accept, q, q0)
154+
self.accepted += accepted
157155

158156
self.steps_until_tune -= 1
159157

@@ -264,7 +262,8 @@ def astep(self, q0, logp):
264262
q[switch_locs] = True - q[switch_locs]
265263

266264
accept = logp(q) - logp(q0)
267-
q_new = metrop_select(accept, q, q0)
265+
q_new, accepted = metrop_select(accept, q, q0)
266+
self.accepted += accepted
268267

269268
stats = {
270269
'tune': self.tune,
@@ -325,8 +324,8 @@ def astep(self, q0, logp):
325324
for idx in order:
326325
curr_val, q[idx] = q[idx], True - q[idx]
327326
logp_prop = logp(q)
328-
q[idx] = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
329-
if q[idx] != curr_val:
327+
q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
328+
if accepted:
330329
logp_curr = logp_prop
331330

332331
return q
@@ -408,10 +407,9 @@ def astep_unif(self, q0, logp):
408407
for dim, k in dimcats:
409408
curr_val, q[dim] = q[dim], sample_except(k, q[dim])
410409
logp_prop = logp(q)
411-
q[dim] = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
412-
if q[dim] != curr_val:
410+
q[dim], accepted = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
411+
if accepted:
413412
logp_curr = logp_prop
414-
415413
return q
416414

417415
def astep_prop(self, q0, logp):

pymc3/step_methods/smc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .arraystep import metrop_select
3232
from ..backends import smc_text as atext
3333

34-
__all__ = ('SMC', 'ATMIP_sample')
34+
__all__ = ['SMC', 'ATMIP_sample']
3535

3636
EXPERIMENTAL_WARNING = "Warning: SMC is an experimental step method, and not yet"\
3737
" recommended for use in PyMC3!"
@@ -242,10 +242,10 @@ def astep(self, q0):
242242

243243
if np.isfinite(varlogp):
244244
logp = self.logp_forw(q)
245-
q_new = metrop_select(
245+
q_new, accepted = metrop_select(
246246
self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0)
247247

248-
if q_new is q:
248+
if accepted:
249249
self.accepted += 1
250250
l_new = logp
251251
self.chain_previous_lpoint[self.chain_index] = l_new
@@ -257,10 +257,10 @@ def astep(self, q0):
257257

258258
else:
259259
logp = self.logp_forw(q)
260-
q_new = metrop_select(
260+
q_new, accepted = metrop_select(
261261
self.beta * (logp[self._llk_index] - l0[self._llk_index]), q, q0)
262262

263-
if q_new is q:
263+
if accepted:
264264
self.accepted += 1
265265
l_new = logp
266266
self.chain_previous_lpoint[self.chain_index] = l_new

0 commit comments

Comments
 (0)