Skip to content

Commit 219652c

Browse files
authored
reduce number of logp evaluations II (#3601)
* reduce number of logp evaluations II * floatX
1 parent b499c73 commit 219652c

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

pymc3/smc/smc.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,14 @@ def sample_smc(
191191

192192
if parallel and cores > 1:
193193
pool = mp.Pool(processes=cores)
194-
results = pool.starmap(likelihood_logp, [(sample,) for sample in posterior])
194+
priors = pool.starmap(prior_logp, [(sample,) for sample in posterior])
195+
likelihoods = pool.starmap(likelihood_logp, [(sample,) for sample in posterior])
195196
else:
196-
results = [likelihood_logp(sample) for sample in posterior]
197-
likelihoods = np.array(results).squeeze()
197+
priors = [prior_logp(sample) for sample in posterior]
198+
likelihoods = [likelihood_logp(sample) for sample in posterior]
199+
200+
priors = np.array(priors).squeeze()
201+
likelihoods = np.array(likelihoods).squeeze()
198202

199203
while beta < 1:
200204
beta, old_beta, weights, sj = calc_beta(beta, likelihoods, threshold)
@@ -203,6 +207,7 @@ def sample_smc(
203207
# resample based on plausibility weights (selection)
204208
resampling_indexes = np.random.choice(np.arange(draws), size=draws, p=weights)
205209
posterior = posterior[resampling_indexes]
210+
priors = priors[resampling_indexes]
206211
likelihoods = likelihoods[resampling_indexes]
207212

208213
# compute proposal distribution based on weights
@@ -219,7 +224,6 @@ def sample_smc(
219224
pm._log.info("Stage: {:3d} Beta: {:.3f} Steps: {:3d}".format(stage, beta, n_steps))
220225
# Apply Metropolis kernel (mutation)
221226
proposed = draws * n_steps
222-
priors = np.array([prior_logp(sample) for sample in posterior]).squeeze()
223227
tempered_logp = priors + likelihoods * beta
224228

225229
parameters = (
@@ -238,18 +242,31 @@ def sample_smc(
238242
results = pool.starmap(
239243
metrop_kernel,
240244
[
241-
(posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters)
245+
(
246+
posterior[draw],
247+
tempered_logp[draw],
248+
priors[draw],
249+
likelihoods[draw],
250+
*parameters,
251+
)
242252
for draw in range(draws)
243253
],
244254
)
245255
else:
246256
results = [
247-
metrop_kernel(posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters)
257+
metrop_kernel(
258+
posterior[draw],
259+
tempered_logp[draw],
260+
priors[draw],
261+
likelihoods[draw],
262+
*parameters
263+
)
248264
for draw in tqdm(range(draws), disable=not progressbar)
249265
]
250266

251-
posterior, acc_list, likelihoods = zip(*results)
267+
posterior, acc_list, priors, likelihoods = zip(*results)
252268
posterior = np.array(posterior)
269+
priors = np.array(priors)
253270
likelihoods = np.array(likelihoods)
254271
acc_rate = sum(acc_list) / proposed
255272
stage += 1

pymc3/smc/smc_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _posterior_to_trace(posterior, variables, model, var_info):
110110
def metrop_kernel(
111111
q_old,
112112
old_tempered_logp,
113+
old_prior,
113114
old_likelihood,
114115
proposal,
115116
scaling,
@@ -142,16 +143,18 @@ def metrop_kernel(
142143
q_new = floatX(q_old + delta)
143144

144145
ll = likelihood_logp(q_new)
146+
pl = prior_logp(q_new)
145147

146-
new_tempered_logp = prior_logp(q_new) + ll * beta
148+
new_tempered_logp = pl + ll * beta
147149

148150
q_old, accept = metrop_select(new_tempered_logp - old_tempered_logp, q_new, q_old)
149151
if accept:
150152
accepted += 1
153+
old_prior = pl
151154
old_likelihood = ll
152155
old_tempered_logp = new_tempered_logp
153156

154-
return q_old, accepted, old_likelihood
157+
return q_old, accepted, old_prior, old_likelihood
155158

156159

157160
def calc_beta(beta, likelihoods, threshold=0.5, psis=True):

0 commit comments

Comments
 (0)