Skip to content

Commit 879cb49

Browse files
authored
SMC: change scaling computation (#3594)
* change scaling computation * add release notes * remove comma
1 parent 0d5203c commit 879cb49

File tree

3 files changed

+32
-22
lines changed

3 files changed

+32
-22
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
1818
- SMC: stabilize covariance matrix [3573](https://github.com/pymc-devs/pymc3/pull/3573)
1919
- SMC is no longer a step method of `pm.sample` now it should be called using `pm.sample_smc` [3579](https://github.com/pymc-devs/pymc3/pull/3579)
20+
- SMC: improve computation of the proposal scaling factor [3594](https://github.com/pymc-devs/pymc3/pull/3594)
2021
- Now uses `multiprocessing` rather than `psutil` to count CPUs, which results in reliable core counts on Chromebooks.
2122
- `sample_posterior_predictive` now preallocates the memory required for its output to improve memory usage. Addresses problems raised in this [discourse thread](https://discourse.pymc.io/t/memory-error-with-posterior-predictive-sample/2891/4).
2223
- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done wrong leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)

pymc3/smc/smc.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,17 @@ def sample_smc(
189189
elif kernel.lower() == "metropolis":
190190
likelihood_logp = logp_forw([model.datalogpt], variables, shared)
191191

192+
if parallel and cores > 1:
193+
pool = mp.Pool(processes=cores)
194+
192195
while beta < 1:
193196
if parallel and cores > 1:
194-
pool = mp.Pool(processes=cores)
195197
results = pool.starmap(likelihood_logp, [(sample,) for sample in posterior])
196198
else:
197199
results = [likelihood_logp(sample) for sample in posterior]
198200
likelihoods = np.array(results).squeeze()
199201
beta, old_beta, weights, sj = calc_beta(beta, likelihoods, threshold)
202+
200203
model.marginal_likelihood *= sj
201204
# resample based on plausibility weights (selection)
202205
resampling_indexes = np.random.choice(np.arange(draws), size=draws, p=weights)
@@ -211,22 +214,14 @@ def sample_smc(
211214
# acceptance rate of the previous stage
212215
if (tune_scaling or tune_steps) and stage > 0:
213216
scaling, n_steps = _tune(
214-
acc_rate,
215-
proposed,
216-
tune_scaling,
217-
tune_steps,
218-
scaling,
219-
n_steps,
220-
max_steps,
221-
p_acc_rate,
217+
acc_rate, proposed, tune_scaling, tune_steps, scaling, max_steps, p_acc_rate
222218
)
223219

224-
pm._log.info("Stage: {:d} Beta: {:.3f} Steps: {:d}".format(stage, beta, n_steps))
220+
pm._log.info("Stage: {:3d} Beta: {:.3f} Steps: {:3d}".format(stage, beta, n_steps))
225221
# Apply Metropolis kernel (mutation)
226222
proposed = draws * n_steps
227223
priors = np.array([prior_logp(sample) for sample in posterior]).squeeze()
228224
tempered_logp = priors + likelihoods * beta
229-
deltas = np.squeeze(proposal(n_steps) * scaling)
230225

231226
parameters = (
232227
proposal,
@@ -240,9 +235,7 @@ def sample_smc(
240235
likelihood_logp,
241236
beta,
242237
)
243-
244238
if parallel and cores > 1:
245-
pool = mp.Pool(processes=cores)
246239
results = pool.starmap(
247240
metrop_kernel,
248241
[(posterior[draw], tempered_logp[draw], *parameters) for draw in range(draws)],
@@ -258,6 +251,9 @@ def sample_smc(
258251
acc_rate = sum(acc_list) / proposed
259252
stage += 1
260253

254+
if parallel and cores > 1:
255+
pool.close()
256+
pool.join()
261257
trace = _posterior_to_trace(posterior, variables, model, var_info)
262258

263259
return trace

pymc3/smc/smc_utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pymc3 as pm
77
import theano
88
from ..step_methods.arraystep import metrop_select
9+
from ..step_methods.metropolis import tune
910
from ..backends.ndarray import NDArray
1011
from ..backends.base import MultiTrace
1112
from ..theanof import floatX, join_nonshared_inputs
@@ -51,7 +52,7 @@ def _calc_covariance(posterior, weights):
5152
return cov
5253

5354

54-
def _tune(acc_rate, proposed, tune_scaling, tune_steps, scaling, n_steps, max_steps, p_acc_rate):
55+
def _tune(acc_rate, proposed, tune_scaling, tune_steps, scaling, max_steps, p_acc_rate):
5556
"""
5657
Tune scaling and/or n_steps based on the acceptance rate.
5758
@@ -61,17 +62,26 @@ def _tune(acc_rate, proposed, tune_scaling, tune_steps, scaling, n_steps, max_st
6162
Acceptance rate of the previous stage
6263
proposed: int
6364
Total number of proposed steps (draws * n_steps)
64-
step: SMC step method
65+
tune_scaling : bool
66+
Whether to compute the scaling factor automatically or not
67+
tune_steps : bool
68+
Whether to compute the number of steps automatically or not
69+
scaling : float
70+
Scaling factor applied to the proposal distribution
71+
max_steps : int
72+
The maximum number of steps of each Markov Chain.
73+
p_acc_rate : float
74+
The higher the value, the higher the number of steps computed automatically. It should be
75+
between 0 and 1.
6576
"""
6677
if tune_scaling:
67-
# a and b after Muto & Beck 2008.
68-
a = 1 / 9
69-
b = 8 / 9
70-
scaling = (a + b * acc_rate) ** 2
78+
scaling = tune(scaling, acc_rate)
7179

7280
if tune_steps:
7381
acc_rate = max(1.0 / proposed, acc_rate)
7482
n_steps = min(max_steps, max(2, int(np.log(1 - p_acc_rate) / np.log(1 - acc_rate))))
83+
else:
84+
n_steps = max_steps
7585

7686
return scaling, n_steps
7787

@@ -115,6 +125,7 @@ def metrop_kernel(
115125
Metropolis kernel
116126
"""
117127
deltas = np.squeeze(proposal(n_steps) * scaling)
128+
118129
for n_step in range(n_steps):
119130
delta = deltas[n_step]
120131

@@ -141,7 +152,7 @@ def metrop_kernel(
141152
return q_old, accepted
142153

143154

144-
def calc_beta(beta, likelihoods, threshold=0.5):
155+
def calc_beta(beta, likelihoods, threshold=0.5, psis=True):
145156
"""
146157
Calculate next inverse temperature (beta) and importance weights based on current beta
147158
and tempered likelihood.
@@ -185,9 +196,11 @@ def calc_beta(beta, likelihoods, threshold=0.5):
185196
low_beta = new_beta
186197
if new_beta >= 1:
187198
new_beta = 1
199+
weights_un = np.exp((new_beta - old_beta) * (likelihoods - likelihoods.max()))
200+
weights = weights_un / np.sum(weights_un)
201+
188202
sj = np.exp((new_beta - old_beta) * likelihoods)
189-
weights_un = np.exp((new_beta - old_beta) * (likelihoods - likelihoods.max()))
190-
weights = weights_un / np.sum(weights_un)
203+
191204
return new_beta, old_beta, weights, np.mean(sj)
192205

193206

0 commit comments

Comments
 (0)