Skip to content

Commit 4952b8e

Browse files
aloctavodiatwiecki
authored andcommitted
Fix discrete sampling (#3240)
bijection only mapped continuous variables (not sure why) but it caused problems for discrete parameters in SMC.
1 parent 6ccc11b commit 4952b8e

File tree

4 files changed

+251
-232
lines changed

4 files changed

+251
-232
lines changed

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def isroot(self):
672672
@property
673673
@memoize(bound=True)
674674
def bijection(self):
675-
vars = inputvars(self.cont_vars)
675+
vars = inputvars(self.vars)
676676

677677
bij = DictToArrayBijection(ArrayOrdering(vars),
678678
self.test_point)

pymc3/step_methods/smc.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99

1010
from .arraystep import metrop_select
1111
from .metropolis import MultivariateNormalProposal
12-
from ..theanof import floatX, make_shared_replacements, join_nonshared_inputs
12+
from ..theanof import floatX, make_shared_replacements, join_nonshared_inputs, inputvars
1313
from ..model import modelcontext
1414
from ..backends.ndarray import NDArray
1515
from ..backends.base import MultiTrace
1616

1717

18-
__all__ = ['SMC', 'sample_smc']
18+
__all__ = ["SMC", "sample_smc"]
1919

20-
proposal_dists = {'MultivariateNormal': MultivariateNormalProposal}
20+
proposal_dists = {"MultivariateNormal": MultivariateNormalProposal}
2121

2222

23-
class SMC():
23+
class SMC:
2424
"""
2525
Sequential Monte Carlo step
2626
@@ -59,8 +59,15 @@ class SMC():
5959
%282007%29133:7%28816%29>`__
6060
"""
6161

62-
def __init__(self, n_steps=5, scaling=1., p_acc_rate=0.01, tune=True,
63-
proposal_name='MultivariateNormal', threshold=0.5):
62+
def __init__(
63+
self,
64+
n_steps=25,
65+
scaling=1.0,
66+
p_acc_rate=0.01,
67+
tune=True,
68+
proposal_name="MultivariateNormal",
69+
threshold=0.5,
70+
):
6471

6572
self.n_steps = n_steps
6673
self.scaling = scaling
@@ -98,15 +105,15 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
98105
stage = 0
99106
acc_rate = 1
100107
model.marginal_likelihood = 1
101-
variables = model.vars
108+
variables = inputvars(model.vars)
102109
discrete = np.concatenate([[v.dtype in pm.discrete_types] * (v.dsize or 1) for v in variables])
103110
any_discrete = discrete.any()
104111
all_discrete = discrete.all()
105112
shared = make_shared_replacements(variables, model)
106113
prior_logp = logp_forw([model.varlogpt], variables, shared)
107114
likelihood_logp = logp_forw([model.datalogpt], variables, shared)
108115

109-
pm._log.info('Sample initial stage: ...')
116+
pm._log.info("Sample initial stage: ...")
110117
posterior, var_info = _initial_population(draws, model, variables)
111118

112119
while beta < 1:
@@ -127,15 +134,18 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
127134
# acceptance rate
128135
if step.tune and stage > 0:
129136
if acc_rate == 0:
130-
acc_rate = 1. / step.n_steps
137+
acc_rate = 1.0 / step.n_steps
131138
step.scaling = _tune(acc_rate)
132139
step.n_steps = 1 + int(np.log(step.p_acc_rate) / np.log(1 - acc_rate))
133140

134-
pm._log.info('Stage: {:d} Beta: {:f} Steps: {:d} Acc: {:f}'.format(stage, beta,
135-
step.n_steps, acc_rate))
141+
pm._log.info(
142+
"Stage: {:d} Beta: {:f} Steps: {:d} Acc: {:f}".format(
143+
stage, beta, step.n_steps, acc_rate
144+
)
145+
)
136146
# Apply Metropolis kernel (mutation)
137-
proposed = 0.
138-
accepted = 0.
147+
proposed = 0.0
148+
accepted = 0.0
139149
priors = np.array([prior_logp(sample) for sample in posterior]).squeeze()
140150
tempered_post = priors + likelihoods * beta
141151
for draw in tqdm(range(draws), disable=not progressbar):
@@ -147,12 +157,12 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
147157

148158
if any_discrete:
149159
if all_discrete:
150-
delta = np.round(delta, 0).astype('int64')
151-
q_old = q_old.astype('int64')
152-
q_new = (q_old + delta).astype('int64')
160+
delta = np.round(delta, 0).astype("int64")
161+
q_old = q_old.astype("int64")
162+
q_new = (q_old + delta).astype("int64")
153163
else:
154164
delta[discrete] = np.round(delta[discrete], 0)
155-
q_new = (q_old + delta)
165+
q_new = floatX(q_old + delta)
156166
else:
157167
q_new = floatX(q_old + delta)
158168

@@ -163,12 +173,12 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
163173
accepted += accept
164174
posterior[draw] = q_old
165175
old_tempered_post = new_tempered_post
166-
proposed += 1.
176+
proposed += 1.0
167177

168178
acc_rate = accepted / proposed
169179
stage += 1
170180

171-
trace = _posterior_to_trace(posterior, model, var_info)
181+
trace = _posterior_to_trace(posterior, variables, model, var_info)
172182

173183
return trace
174184

@@ -177,6 +187,7 @@ def _initial_population(draws, model, variables):
177187
"""
178188
Create an initial population from the prior
179189
"""
190+
180191
population = []
181192
var_info = {}
182193
start = model.test_point
@@ -188,7 +199,7 @@ def _initial_population(draws, model, variables):
188199
point = pm.Point({v.name: init_rnd[v.name][i] for v in variables}, model=model)
189200
population.append(model.dict_to_array(point))
190201

191-
return np.array(population), var_info
202+
return np.array(floatX(population)), var_info
192203

193204

194205
def _calc_beta(beta, likelihoods, threshold=0.5):
@@ -219,11 +230,11 @@ def _calc_beta(beta, likelihoods, threshold=0.5):
219230
Partial marginal likelihood
220231
"""
221232
low_beta = old_beta = beta
222-
up_beta = 2.
233+
up_beta = 2.0
223234
rN = int(len(likelihoods) * threshold)
224235

225236
while up_beta - low_beta > 1e-6:
226-
new_beta = (low_beta + up_beta) / 2.
237+
new_beta = (low_beta + up_beta) / 2.0
227238
weights_un = np.exp((new_beta - old_beta) * (likelihoods - likelihoods.max()))
228239
weights = weights_un / np.sum(weights_un)
229240
ESS = int(1 / np.sum(weights ** 2))
@@ -241,11 +252,11 @@ def _calc_beta(beta, likelihoods, threshold=0.5):
241252
return new_beta, old_beta, weights, np.mean(sj)
242253

243254

244-
def _calc_covariance(posterior_array, weights):
255+
def _calc_covariance(posterior, weights):
245256
"""
246257
Calculate trace covariance matrix based on importance weights.
247258
"""
248-
cov = np.cov(np.squeeze(posterior_array), aweights=weights.ravel(), bias=False, rowvar=0)
259+
cov = np.cov(posterior, aweights=weights.ravel(), bias=False, rowvar=0)
249260
if np.isnan(cov).any() or np.isinf(cov).any():
250261
raise ValueError('Sample covariances not valid! Likely "chains" is too small!')
251262
return np.atleast_2d(cov)
@@ -264,18 +275,18 @@ def _tune(acc_rate):
264275
-------
265276
scaling: float
266277
"""
267-
# a and b after Muto & Beck 2008 .
268-
a = 1. / 9
269-
b = 8. / 9
278+
# a and b after Muto & Beck 2008.
279+
a = 1.0 / 9
280+
b = 8.0 / 9
270281
return (a + b * acc_rate) ** 2
271282

272283

273-
def _posterior_to_trace(posterior, model, var_info):
284+
def _posterior_to_trace(posterior, variables, model, var_info):
274285
"""
275286
Save results into a PyMC3 trace
276287
"""
277288
lenght_pos = len(posterior)
278-
varnames = [v.name for v in model.vars]
289+
varnames = [v.name for v in variables]
279290

280291
with model:
281292
strace = NDArray(model)
@@ -285,7 +296,7 @@ def _posterior_to_trace(posterior, model, var_info):
285296
size = 0
286297
for var in varnames:
287298
shape, new_size = var_info[var]
288-
value.append(posterior[i][size:size+new_size].reshape(shape))
299+
value.append(posterior[i][size : size + new_size].reshape(shape))
289300
size += new_size
290301
strace.record({k: v for k, v in zip(varnames, value)})
291302
return MultiTrace([strace])

pymc3/tests/test_smc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def test_sample(self):
4747
mu1d = np.abs(x).mean(axis=0)
4848
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.03)
4949

50+
51+
def test_discrete_continuous(self):
52+
with pm.Model() as model:
53+
a = pm.Poisson('a', 5)
54+
b = pm.HalfNormal('b', 10)
55+
y = pm.Normal('y', a, b, observed=[1, 2, 3, 4])
56+
trace = pm.sample(step=pm.SMC())
57+
58+
5059
def test_ml(self):
5160
data = np.repeat([1, 0], [50, 50])
5261
marginals = []

0 commit comments

Comments
 (0)