Skip to content

Commit a208c0d

Browse files
committed
reduce number of logp evaluations
1 parent 879cb49 commit a208c0d

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pymc3/smc/smc.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,12 @@ def sample_smc(
191191

192192
if parallel and cores > 1:
193193
pool = mp.Pool(processes=cores)
194+
results = pool.starmap(likelihood_logp, [(sample,) for sample in posterior])
195+
else:
196+
results = [likelihood_logp(sample) for sample in posterior]
197+
likelihoods = np.array(results).squeeze()
194198

195199
while beta < 1:
196-
if parallel and cores > 1:
197-
results = pool.starmap(likelihood_logp, [(sample,) for sample in posterior])
198-
else:
199-
results = [likelihood_logp(sample) for sample in posterior]
200-
likelihoods = np.array(results).squeeze()
201200
beta, old_beta, weights, sj = calc_beta(beta, likelihoods, threshold)
202201

203202
model.marginal_likelihood *= sj
@@ -238,16 +237,20 @@ def sample_smc(
238237
if parallel and cores > 1:
239238
results = pool.starmap(
240239
metrop_kernel,
241-
[(posterior[draw], tempered_logp[draw], *parameters) for draw in range(draws)],
240+
[
241+
(posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters)
242+
for draw in range(draws)
243+
],
242244
)
243245
else:
244246
results = [
245-
metrop_kernel(posterior[draw], tempered_logp[draw], *parameters)
247+
metrop_kernel(posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters)
246248
for draw in tqdm(range(draws), disable=not progressbar)
247249
]
248250

249-
posterior, acc_list = zip(*results)
251+
posterior, acc_list, likelihoods = zip(*results)
250252
posterior = np.array(posterior)
253+
likelihoods = np.array(likelihoods)
251254
acc_rate = sum(acc_list) / proposed
252255
stage += 1
253256

pymc3/smc/smc_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _posterior_to_trace(posterior, variables, model, var_info):
110110
def metrop_kernel(
111111
q_old,
112112
old_tempered_logp,
113+
old_likelihood,
113114
proposal,
114115
scaling,
115116
accepted,
@@ -147,9 +148,10 @@ def metrop_kernel(
147148
q_old, accept = metrop_select(new_tempered_logp - old_tempered_logp, q_new, q_old)
148149
if accept:
149150
accepted += 1
151+
old_likelihood = ll
150152
old_tempered_logp = new_tempered_logp
151153

152-
return q_old, accepted
154+
return q_old, accepted, old_likelihood
153155

154156

155157
def calc_beta(beta, likelihoods, threshold=0.5, psis=True):

0 commit comments

Comments
 (0)