Skip to content

Commit 2e35805

Browse files
committed
Merge pull request #799 from pymc-devs/gibbs
Add Binary Gibbs Metropolis sampler for mixture models.
2 parents d18a0aa + 15737a9 commit 2e35805

File tree

3 files changed

+88
-48
lines changed

3 files changed

+88
-48
lines changed

pymc3/sampling.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
__all__ = ['sample', 'iter_sample', 'sample_ppc']
1313

1414
def assign_step_methods(model, step=None,
15-
methods=(NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
15+
methods=(NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, BinaryGibbsMetropolis,
1616
Slice, ElemwiseCategoricalStep)):
1717
'''
18-
Assign model variables to appropriate step methods. Passing a specified
18+
Assign model variables to appropriate step methods. Passing a specified
1919
model will auto-assign its constituent stochastic variables to step methods
2020
based on the characteristics of the variables. This function is intended to
2121
be called automatically from `sample()`, but may be called manually. Each
2222
step method passed should have a `competence()` method that returns an
2323
ordinal competence value corresponding to the variable passed to it. This
24-
value quantifies the appropriateness of the step method for sampling the
25-
variable.
26-
24+
value quantifies the appropriateness of the step method for sampling the
25+
variable.
26+
2727
Parameters
2828
----------
29-
29+
3030
model : Model object
3131
A fully-specified model object
3232
step : step function or vector of step functions
@@ -35,12 +35,12 @@ def assign_step_methods(model, step=None,
3535
methods : vector of step method classes
3636
The set of step methods from which the function may choose. Defaults
3737
to the main step methods provided by PyMC3.
38-
38+
3939
Returns
4040
-------
4141
List of step methods associated with the model's variables.
4242
'''
43-
43+
4444
steps = []
4545
assigned_vars = set()
4646
if step is not None:
@@ -51,26 +51,26 @@ def assign_step_methods(model, step=None,
5151
except AttributeError:
5252
for m in s.methods:
5353
assigned_vars = assigned_vars | set(m.vars)
54-
54+
5555
# Use competence classmethods to select step methods for remaining variables
5656
selected_steps = defaultdict(list)
5757
for var in model.free_RVs:
5858
if not var in assigned_vars:
59-
59+
6060
competences = {s:s._competence(var) for s in methods}
6161

6262
selected = max(competences.keys(), key=(lambda k: competences[k]))
63-
63+
6464
if model.verbose:
6565
print('Assigned {0} to {1}'.format(selected.__name__, var))
6666
selected_steps[selected].append(var)
67-
67+
6868
# Instantiate all selected step methods
6969
steps += [s(vars=selected_steps[s]) for s in selected_steps if selected_steps[s]]
70-
70+
7171
if len(steps)==1:
72-
steps = steps[0]
73-
72+
steps = steps[0]
73+
7474
return steps
7575

7676
def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None,
@@ -86,8 +86,8 @@ def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None
8686
draws : int
8787
The number of samples to draw
8888
step : function or iterable of functions
89-
A step function or collection of functions. If no step methods are
90-
specified, or are partially specified, they will be assigned
89+
A step function or collection of functions. If no step methods are
90+
specified, or are partially specified, they will be assigned
9191
automatically (defaults to None).
9292
start : dict
9393
Starting point in parameter space (or partial point)
@@ -120,7 +120,7 @@ def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None
120120
MultiTrace object with access to sampling values
121121
"""
122122
model = modelcontext(model)
123-
123+
124124
step = assign_step_methods(model, step)
125125

126126
if njobs is None:
@@ -184,7 +184,7 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
184184
185185
draws : int
186186
The number of samples to draw
187-
step : function
187+
step : function
188188
Step function
189189
start : dict
190190
Starting point in parameter space (or partial point)

pymc3/step_methods/metropolis.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..theanof import make_shared_replacements, join_nonshared_inputs, CallableTensor
1414

1515

16-
__all__ = ['Metropolis', 'BinaryMetropolis', 'NormalProposal', 'CauchyProposal', 'LaplaceProposal', 'PoissonProposal', 'MultivariateNormalProposal']
16+
__all__ = ['Metropolis', 'BinaryMetropolis', 'BinaryGibbsMetropolis', 'NormalProposal', 'CauchyProposal', 'LaplaceProposal', 'PoissonProposal', 'MultivariateNormalProposal']
1717

1818
# Available proposal distributions for Metropolis
1919

@@ -71,7 +71,7 @@ class Metropolis(ArrayStepShared):
7171
7272
"""
7373
default_blocked = False
74-
74+
7575
def __init__(self, vars=None, S=None, proposal_dist=NormalProposal, scaling=1.,
7676
tune=True, tune_interval=100, model=None, **kwargs):
7777

@@ -110,6 +110,7 @@ def astep(self, q0):
110110
self.accepted = 0
111111

112112
delta = self.proposal_dist() * self.scaling
113+
113114
if self.any_discrete:
114115
if self.all_discrete:
115116
delta = round(delta, 0).astype(int)
@@ -121,16 +122,15 @@ def astep(self, q0):
121122
else:
122123
q = q0 + delta
123124

124-
125-
q_new = metrop_select(self.delta_logp(q,q0), q, q0)
125+
q_new = metrop_select(self.delta_logp(q, q0), q, q0)
126126

127127
if q_new is q:
128128
self.accepted += 1
129129

130130
self.steps_until_tune -= 1
131131

132132
return q_new
133-
133+
134134
@staticmethod
135135
def competence(var):
136136
if var.dtype in discrete_types:
@@ -179,7 +179,7 @@ def tune(scale, acc_rate):
179179

180180
class BinaryMetropolis(ArrayStep):
181181
"""Metropolis-Hastings optimized for binary variables"""
182-
182+
183183
def __init__(self, vars, scaling=1., tune=True, tune_interval=100, model=None):
184184

185185
model = modelcontext(model)
@@ -197,31 +197,71 @@ def __init__(self, vars, scaling=1., tune=True, tune_interval=100, model=None):
197197
super(BinaryMetropolis, self).__init__(vars, [model.fastlogp])
198198

199199
def astep(self, q0, logp):
200-
201200
# Convert adaptive_scale_factor to a jump probability
202201
p_jump = 1. - .5 ** self.scaling
203202

204203
rand_array = random(q0.shape)
205204
q = copy(q0)
206205
# Locations where switches occur, according to p_jump
207-
switch_locs = where(rand_array < p_jump)
206+
switch_locs = (rand_array < p_jump)
208207
q[switch_locs] = True - q[switch_locs]
209-
210208
q_new = metrop_select(logp(q) - logp(q0), q, q0)
211209

212210
return q_new
213-
211+
212+
@staticmethod
213+
def competence(var):
214+
'''
215+
BinaryMetropolis is only suitable for binary (bool)
216+
and Categorical variables with k=1.
217+
'''
218+
if isinstance(var.distribution, Bernoulli) or (var.dtype in bool_types):
219+
return Competence.compatible
220+
elif isinstance(var.distribution, Categorical) and (var.distribution.k == 2):
221+
return Competence.compatible
222+
return Competence.incompatible
223+
224+
class BinaryGibbsMetropolis(ArrayStep):
225+
"""Metropolis-Hastings optimized for binary variables"""
226+
227+
def __init__(self, vars, order='random', model=None):
228+
229+
model = modelcontext(model)
230+
231+
self.dim = sum(v.dsize for v in vars)
232+
self.order = order
233+
234+
if not all([v.dtype in discrete_types for v in vars]):
235+
raise ValueError(
236+
'All variables must be Bernoulli for BinaryGibbsMetropolis')
237+
238+
super(BinaryGibbsMetropolis, self).__init__(vars, [model.fastlogp])
239+
240+
def astep(self, q0, logp):
241+
order = list(range(self.dim))
242+
if self.order == 'random':
243+
np.random.shuffle(order)
244+
245+
q_prop = copy(q0)
246+
q_cur = copy(q0)
247+
248+
for idx in order:
249+
q_prop[idx] = True - q_prop[idx]
250+
q_cur = metrop_select(logp(q_prop) - logp(q_cur), q_prop, q_cur)
251+
q_prop = copy(q_cur)
252+
253+
return q_cur
254+
214255
@staticmethod
215256
def competence(var):
216257
'''
217-
BinaryMetropolis is only suitable for binary (bool)
258+
BinaryMetropolis is only suitable for binary (bool)
218259
and Categorical variables with k=1.
219260
'''
220261
if isinstance(var.distribution, Bernoulli) or (var.dtype in bool_types):
221262
return Competence.ideal
222-
if isinstance(var.distribution, Categorical):
223-
if var.distribution.k==2:
224-
return Competence.ideal
263+
elif isinstance(var.distribution, Categorical) and (var.distribution.k == 2):
264+
return Competence.ideal
225265
return Competence.incompatible
226266

227267
def delta_logp(logp, vars, shared):

pymc3/tests/test_step.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from scipy.stats.mstats import moment
55
from pymc3.sampling import assign_step_methods
66
from pymc3.model import Model
7-
from pymc3.step_methods import NUTS, BinaryMetropolis, Metropolis, Constant, ElemwiseCategoricalStep
7+
from pymc3.step_methods import NUTS, BinaryMetropolis, BinaryGibbsMetropolis, Metropolis, Constant, ElemwiseCategoricalStep
88
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
99
from numpy.testing import assert_almost_equal
1010

@@ -96,41 +96,41 @@ def test_step_discrete():
9696
yield check_stat, repr(st), h, var, stat, val, bound
9797

9898
def test_constant_step():
99-
99+
100100
with Model() as model:
101101
x = Normal('x', 0, 1)
102102
start = {'x':-1}
103103
tr = sample(10, step=Constant([x]), start=start)
104104
assert_almost_equal(tr['x'], start['x'], decimal=10)
105105

106106
def test_assign_step_methods():
107-
107+
108108
with Model() as model:
109109
x = Bernoulli('x', 0.5)
110110
steps = assign_step_methods(model, [])
111-
112-
assert isinstance(steps, BinaryMetropolis)
113-
111+
112+
assert isinstance(steps, BinaryGibbsMetropolis)
113+
114114
with Model() as model:
115115
x = Normal('x', 0, 1)
116116
steps = assign_step_methods(model, [])
117-
117+
118118
assert isinstance(steps, NUTS)
119-
119+
120120
with Model() as model:
121121
x = Categorical('x', np.array([0.25, 0.75]))
122122
steps = assign_step_methods(model, [])
123-
124-
assert isinstance(steps, BinaryMetropolis)
125-
123+
124+
assert isinstance(steps, BinaryGibbsMetropolis)
125+
126126
with Model() as model:
127127
x = Categorical('x', np.array([0.25, 0.70, 0.05]))
128128
steps = assign_step_methods(model, [])
129-
129+
130130
assert isinstance(steps, ElemwiseCategoricalStep)
131-
131+
132132
with Model() as model:
133133
x = Binomial('x', 10, 0.5)
134134
steps = assign_step_methods(model, [])
135-
136-
assert isinstance(steps, Metropolis)
135+
136+
assert isinstance(steps, Metropolis)

0 commit comments

Comments
 (0)