Skip to content

Commit 4b3620c

Browse files
author
Junpeng Lao
authored
Merge pull request #2867 from junpenglao/fix_2866
fix BinaryGibbsMetropolis issue for p=.5
2 parents 3c18c4d + e3c4422 commit 4b3620c

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

pymc3/step_methods/metropolis.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -297,13 +297,31 @@ def competence(var):
297297

298298

299299
class BinaryGibbsMetropolis(ArrayStep):
300-
"""A Metropolis-within-Gibbs step method optimized for binary variables"""
300+
"""A Metropolis-within-Gibbs step method optimized for binary variables
301+
302+
Parameters
303+
----------
304+
vars : list
305+
List of variables for sampler
306+
order : list or 'random'
307+
List of integers indicating the Gibbs update order
308+
e.g., [0, 2, 1, ...]. Default is random
309+
transit_p : float
310+
The diagonal of the transition kernel. A value > .5 gives anticorrelated proposals,
311+
which resulting in more efficient antithetical sampling.
312+
model : PyMC Model
313+
Optional model for sampling step. Defaults to None (taken from context).
314+
315+
"""
301316
name = 'binary_gibbs_metropolis'
302317

303-
def __init__(self, vars, order='random', model=None):
318+
def __init__(self, vars, order='random', transit_p=.8, model=None):
304319

305320
model = pm.modelcontext(model)
306321

322+
# transition probabilities
323+
self.transit_p = transit_p
324+
307325
self.dim = sum(v.dsize for v in vars)
308326

309327
if order == 'random':
@@ -330,11 +348,14 @@ def astep(self, q0, logp):
330348
logp_curr = logp(q)
331349

332350
for idx in order:
333-
curr_val, q[idx] = q[idx], True - q[idx]
334-
logp_prop = logp(q)
335-
q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
336-
if accepted:
337-
logp_curr = logp_prop
351+
# No need to do metropolis update if the same value is proposed,
352+
# as you will get the same value regardless of accepted or reject
353+
if nr.rand() < self.transit_p:
354+
curr_val, q[idx] = q[idx], True - q[idx]
355+
logp_prop = logp(q)
356+
q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
357+
if accepted:
358+
logp_curr = logp_prop
338359

339360
return q
340361

0 commit comments

Comments
 (0)