
Commit ff0156d

update SMC v4 (#334)
* update SMC v4
* use initval
* remove potential warnings
1 parent e88675c commit ff0156d

File tree

2 files changed (+178, -148 lines)


examples/samplers/SMC2_gaussians.ipynb

Lines changed: 150 additions & 116 deletions
Large diffs are not rendered by default.

myst_nbs/samplers/SMC2_gaussians.myst.md

Lines changed: 28 additions & 32 deletions
@@ -6,25 +6,23 @@ jupytext:
     format_version: 0.13
     jupytext_version: 1.13.7
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: Python 3.9.7 ('base')
   language: python
   name: python3
 ---

 # Sequential Monte Carlo

 :::{post} Oct 19, 2021
-:tags: SMC, pymc3.Model, pymc3.Potential, pymc3.Uniform, pymc3.sample_smc
+:tags: SMC
 :category: beginner
 :::

 ```{code-cell} ipython3
-%matplotlib inline
+import aesara.tensor as at
 import arviz as az
-import matplotlib.pyplot as plt
 import numpy as np
-import pymc3 as pm
-import theano.tensor as tt
+import pymc as pm

 print(f"Running on PyMC v{pm.__version__}")
 ```
@@ -46,7 +44,7 @@ When $\beta=0$ we have that $p(\theta \mid y)_{\beta=0}$ is the prior distribution
 A summary of the algorithm is:

 1. Initialize $\beta$ at zero and stage at zero.
-2. Generate N samples $S_\text{\beta}$ from the prior (because when $\beta = 0$ the tempered posterior is the prior).
+2. Generate N samples $S_{\beta}$ from the prior (because when $\beta = 0$ the tempered posterior is the prior).
 3. Increase $\beta$ in order to make the effective sample size equal some predefined value (we use $Nt$, where $t$ is 0.5 by default).
 4. Compute a set of N importance weights $W$. The weights are computed as the ratio of the likelihoods of a sample at stage $i+1$ and stage $i$.
 5. Obtain $S_{w}$ by re-sampling according to $W$.
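To make these steps concrete, here is a minimal NumPy sketch of the tempering loop, written against the tempered posterior $p(\theta \mid y)_{\beta} \propto p(y \mid \theta)^{\beta}\,p(\theta)$. It is an illustration only, not PyMC's implementation: `smc_temper`, `sample_prior`, and `log_likelihood` are hypothetical names, and the MCMC mutation steps that follow step 5 in the full algorithm are omitted here.

```python
# Minimal sketch of SMC tempering (steps 1-5 above); NOT PyMC's implementation.
# `sample_prior(draws)` and `log_likelihood(sample)` are hypothetical stand-ins.
import numpy as np
from scipy.special import logsumexp


def ess(log_w):
    """Effective sample size of (possibly unnormalized) log-weights: 1 / sum(w**2)."""
    log_w = log_w - logsumexp(log_w)  # normalize the weights
    return np.exp(-logsumexp(2 * log_w))


def smc_temper(sample_prior, log_likelihood, draws=2000, t=0.5, seed=None):
    rng = np.random.default_rng(seed)
    # steps 1-2: beta = 0, so the tempered posterior is the prior
    samples = sample_prior(draws)
    loglik = np.array([log_likelihood(s) for s in samples])
    beta = 0.0
    while beta < 1.0:
        # step 3: pick the increment of beta that brings the ESS down to
        # draws * t, by bisection; jump straight to beta = 1 if the ESS
        # never drops that far
        low, high = 0.0, 1.0 - beta
        if ess(high * loglik) >= draws * t:
            d_beta = high
            beta = 1.0
        else:
            for _ in range(100):
                d_beta = (low + high) / 2
                if ess(d_beta * loglik) > draws * t:
                    low = d_beta
                else:
                    high = d_beta
            beta += d_beta
        # step 4: the importance weights are the ratio of tempered
        # likelihoods between consecutive stages: log W_i = d_beta * loglik_i
        w = np.exp(d_beta * loglik - logsumexp(d_beta * loglik))
        # step 5: resample according to W
        idx = rng.choice(draws, size=draws, p=w / w.sum())
        samples, loglik = samples[idx], loglik[idx]
        # (the full algorithm then mutates the resampled particles with
        # MCMC steps before the next stage)
    return samples
```

PyMC's `sample_smc` automates all of this, plus the mutation kernel and diagnostics, which is what the notebook calls below.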
@@ -110,16 +108,16 @@ w2 = 1 - w1 # the other mode with 0.9 of the mass

 def two_gaussians(x):
     log_like1 = (
-        -0.5 * n * tt.log(2 * np.pi)
-        - 0.5 * tt.log(dsigma)
+        -0.5 * n * at.log(2 * np.pi)
+        - 0.5 * at.log(dsigma)
         - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
     )
     log_like2 = (
-        -0.5 * n * tt.log(2 * np.pi)
-        - 0.5 * tt.log(dsigma)
+        -0.5 * n * at.log(2 * np.pi)
+        - 0.5 * at.log(dsigma)
         - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
     )
-    return pm.math.logsumexp([tt.log(w1) + log_like1, tt.log(w2) + log_like2])
+    return pm.math.logsumexp([at.log(w1) + log_like1, at.log(w2) + log_like2])
 ```

 ```{code-cell} ipython3
@@ -129,11 +127,10 @@ with pm.Model() as model:
         shape=n,
         lower=-2.0 * np.ones_like(mu1),
         upper=2.0 * np.ones_like(mu1),
-        testval=-1.0 * np.ones_like(mu1),
+        initval=-1.0 * np.ones_like(mu1),
     )
     llk = pm.Potential("llk", two_gaussians(X))
-    trace_04 = pm.sample_smc(2000, parallel=True)
-    idata_04 = az.from_pymc3(trace_04)
+    idata_04 = pm.sample_smc(2000)
 ```

 We can see from the message that PyMC is running four **SMC chains** in parallel. As explained before, this is useful for diagnostics. As with other samplers, one useful diagnostic is `plot_trace`; here we use `kind="rank_vlines"`, as rank plots are generally more useful than the classical "trace" plot.
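As an aside, if you want more or fewer chains for these diagnostics, a hedged sketch follows; the `chains` and `cores` keyword names are assumed from the PyMC v4 `sample_smc` signature and may vary by version:

```python
with model:
    # assumed keywords: run 8 independent SMC chains across 4 cores
    idata_more_chains = pm.sample_smc(2000, chains=8, cores=4)
```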
@@ -142,7 +139,7 @@ We can see from the message that PyMC is running four **SMC chains** in parallel
 ax = az.plot_trace(idata_04, compact=True, kind="rank_vlines")
 ax[0, 0].axvline(-0.5, 0, 0.9, color="k")
 ax[0, 0].axvline(0.5, 0, 0.1, color="k")
-f'Estimated w1 = {np.mean(idata_04.posterior["X"] > 0).item():.3f}'
+f'Estimated w1 = {np.mean(idata_04.posterior["X"] < 0).item():.3f}'
 ```

 From the KDE we can see that we recover the modes, and even the relative weights seem pretty good. The rank plot on the right looks good too. One SMC chain is represented in blue and the other in orange. The vertical lines indicate deviation from the ideal expected value, which is represented with a black dashed line. If a vertical line is above the reference black dashed line we have more samples than expected; if the vertical line is below, the sampler is getting fewer samples than expected. Deviations like the ones in the figure above are fine and not a reason for concern.
@@ -155,7 +152,7 @@ As previously said SMC internally computes an estimation of the ESS (from importance weights)

 SMC is not free of problems: sampling can deteriorate as the dimensionality of the problem increases, in particular for multimodal posteriors or _weird_ geometries as in hierarchical models. To some extent, increasing the number of draws could help. Increasing the value of the argument `p_acc_rate` is also a good idea. This parameter controls how the number of steps is computed at each stage. To access the number of steps per stage you can check `trace.report.nsteps`. Ideally SMC will take a number of steps lower than `n_steps`. But if the actual number of steps per stage hits `n_steps` for a few stages, this may be a signal that we should also increase `n_steps`.

-Let's see the performance of SMC when we run the same model as before, but increasing the dimensionality from 4 to 80.
+Let's see the performance of SMC when we run the same model as before, but increasing the dimensionality from 4 to 80.

 ```{code-cell} ipython3
 n = 80
@@ -168,45 +165,44 @@ sigma = np.power(stdev, 2) * np.eye(n)
 isigma = np.linalg.inv(sigma)
 dsigma = np.linalg.det(sigma)

-w1 = 0.1
-w2 = 1 - w1
-```
+w1 = 0.1 # one mode with 0.1 of the mass
+w2 = 1 - w1 # the other mode with 0.9 of the mass
+

-```{code-cell} ipython3
 def two_gaussians(x):
     log_like1 = (
-        -0.5 * n * tt.log(2 * np.pi)
-        - 0.5 * tt.log(dsigma)
+        -0.5 * n * at.log(2 * np.pi)
+        - 0.5 * at.log(dsigma)
         - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
     )
     log_like2 = (
-        -0.5 * n * tt.log(2 * np.pi)
-        - 0.5 * tt.log(dsigma)
+        -0.5 * n * at.log(2 * np.pi)
+        - 0.5 * at.log(dsigma)
         - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
     )
-    return pm.math.logsumexp([tt.log(w1) + log_like1, tt.log(w2) + log_like2])
-
+    return pm.math.logsumexp([at.log(w1) + log_like1, at.log(w2) + log_like2])
+```

+```{code-cell} ipython3
 with pm.Model() as model:
     X = pm.Uniform(
         "X",
         shape=n,
         lower=-2.0 * np.ones_like(mu1),
         upper=2.0 * np.ones_like(mu1),
-        testval=-1.0 * np.ones_like(mu1),
+        initval=-1.0 * np.ones_like(mu1),
     )
     llk = pm.Potential("llk", two_gaussians(X))
-    trace_80 = pm.sample_smc(2000, parallel=True)
-    idata_80 = az.from_pymc3(trace_80)
+    idata_80 = pm.sample_smc(2000)
 ```

-We see that SMC recognizes this is a harder problem and increases the number of stages. We can see that SMC still sample from both modes but now the model with less weight is being subsampled (we get a relative weight way lower than 0.1). Notice how the rank plot looks worse than when n=4.
+We see that SMC recognizes this is a harder problem and increases the number of stages. We can see that SMC still samples from both modes, but now the mode with higher weight is being oversampled (we get a relative weight of 0.99 instead of 0.9). Notice how the rank plot looks worse than when n=4.

 ```{code-cell} ipython3
 ax = az.plot_trace(idata_80, compact=True, kind="rank_vlines")
 ax[0, 0].axvline(-0.5, 0, 0.9, color="k")
 ax[0, 0].axvline(0.5, 0, 0.1, color="k")
-f'Estimated w1 = {np.mean(idata_80.posterior["X"] > 0).item():.3f}'
+f'Estimated w1 = {np.mean(idata_80.posterior["X"] < 0).item():.3f}'
 ```

 You may want to repeat the SMC sampling for n=80, and change one or more of the default parameters to see if you can improve the sampling and how much time the sampler takes to compute the posterior.
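Following that suggestion, here is a sketch of one such experiment. Only the larger `draws` value is guaranteed by the API used in this diff; the commented-out keywords echo the `n_steps` and `p_acc_rate` knobs mentioned in the text above, which come from the older `sample_smc` signature and may not exist under these names in every PyMC version:

```python
with model:
    # more draws is the simplest knob and is part of the API used above
    idata_80_tuned = pm.sample_smc(4000)

    # the text also suggests raising `p_acc_rate` (and possibly `n_steps`);
    # these keyword names follow the older signature and may differ across
    # PyMC versions, e.g.:
    # idata_80_tuned = pm.sample_smc(4000, n_steps=50, p_acc_rate=0.95)

az.plot_trace(idata_80_tuned, compact=True, kind="rank_vlines")
```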
