
Commit 9e8975f

SMC-ABC add distance, refactor and update notebook (#3996)
* update notebook
* move dist functions out of simulator class
* fix docstring
* add warning and test for automatic selection of sort sum_stat when using wasserstein and energy distances
* update release notes
* fix typo
* add sim_data test
* update and add tests
* update and add tests
1 parent 692a09f commit 9e8975f

File tree

6 files changed: +393 additions, −278 deletions


RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 - Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing cnotext is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
 - Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [ #3986](https://github.com/pymc-devs/pymc3/pull/3986)).
 - Extend `keep_size` argument handling for `sample_posterior_predictive` and `fast_sample_posterior_predictive`, to work on arviz InferenceData and xarray Dataset input values. (see [PR #4006](https://github.com/pymc-devs/pymc3/pull/4006) and [Issue #4004](https://github.com/pymc-devs/pymc3/issues/4004).
+- SMC-ABC: add the wasserstein and energy distance functions. Refactor API: the distance, sum_stat and epsilon arguments are now passed to `pm.Simulator` instead of `pm.sample_smc`. Add random method to `pm.Simulator`. Add option to save the simulated data. Improve the LaTeX representation [#3996](https://github.com/pymc-devs/pymc3/pull/3996)
 
 ## PyMC3 3.9.2 (24 June 2020)
 ### Maintenance
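
For orientation, here is a minimal sketch of how the refactored API described in the release note reads in practice. The toy data, the `normal_sim` simulator, and the parameter names `a` and `b` are illustrative assumptions, not part of this commit; only the `pm.Simulator(..., distance=..., sum_stat=..., epsilon=...)` keywords and `pm.sample_smc(kernel="ABC")` come from the diffs below.

```python
import numpy as np
import pymc3 as pm

# Illustrative observed data and forward simulator (not from this PR).
data = np.random.normal(loc=0.0, scale=1.0, size=1000)

def normal_sim(a, b):
    # Forward model: generates one synthetic dataset for a given parameter draw.
    return np.random.normal(a, b, 1000)

with pm.Model() as model:
    a = pm.Normal("a", mu=0, sigma=5)
    b = pm.HalfNormal("b", sigma=1)
    # distance, sum_stat and epsilon now live on the Simulator itself
    # instead of being passed to pm.sample_smc.
    s = pm.Simulator(
        "s",
        normal_sim,
        params=(a, b),
        distance="gaussian_kernel",
        sum_stat="sort",
        epsilon=1,
        observed=data,
    )
    trace = pm.sample_smc(kernel="ABC")
```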

docs/source/notebooks/SMC-ABC_Lotka-Volterra_example.ipynb

Lines changed: 138 additions & 201 deletions
Large diffs are not rendered by default.

pymc3/distributions/simulator.py

Lines changed: 99 additions & 13 deletions
@@ -12,28 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import numpy as np
-from .distribution import NoDistribution
+from .distribution import NoDistribution, draw_values
 
 __all__ = ["Simulator"]
 
+_log = logging.getLogger("pymc3")
+
 
 class Simulator(NoDistribution):
-    def __init__(self, function, *args, params=None, **kwargs):
+    def __init__(
+        self,
+        function,
+        *args,
+        params=None,
+        distance="gaussian_kernel",
+        sum_stat="identity",
+        epsilon=1,
+        **kwargs,
+    ):
         """
-        This class stores a function defined by the user in python language.
+        This class stores a function defined by the user in Python language.
 
         function: function
-            Simulation function defined by the user.
+            Python function defined by the user.
         params: list
             Parameters passed to function.
+        distance: str or callable
+            Distance functions. Available options are "gaussian_kernel" (default), "wasserstein",
+            "energy" or a user defined function that takes epsilon (a scalar), and the summary
+            statistics of observed_data and simulated_data as input.
+            ``gaussian_kernel`` :math: `\sum \left(-0.5 \left(\frac{xo - xs}{\epsilon}\right)^2\right)`
+            ``wasserstein`` :math: `\frac{1}{n} \sum{\left(\frac{|xo - xs|}{\epsilon}\right)}`
+            ``energy`` :math: `\sqrt{2} \sqrt{\frac{1}{n} \sum \left(\frac{|xo - xs|}{\epsilon}\right)^2}`
+            For the wasserstein and energy distances the observed data xo and simulated data xs
+            are internally sorted (i.e. the sum_stat is "sort").
+        sum_stat: str or callable
+            Summary statistics. Available options are ``identity``, ``sort``, ``mean``, ``median``.
+            If a callable is passed it should return a number or a 1d numpy array.
+        epsilon: float
+            Standard deviation of the gaussian_kernel.
        *args and **kwargs:
            Arguments and keywords arguments that the function takes.
        """
 
         self.function = function
         self.params = params
         observed = self.data
+        self.epsilon = epsilon
+
+        if distance == "gaussian_kernel":
+            self.distance = gaussian_kernel
+        elif distance == "wasserstein":
+            self.distance = wasserstein
+            if sum_stat != "sort":
+                _log.info(f"Automatically setting sum_stat to sort as expected by {distance}")
+                sum_stat = "sort"
+        elif distance == "energy":
+            self.distance = energy
+            if sum_stat != "sort":
+                _log.info(f"Automatically setting sum_stat to sort as expected by {distance}")
+                sum_stat = "sort"
+        elif hasattr(distance, "__call__"):
+            self.distance = distance
+        else:
+            raise ValueError(f"The distance metric {distance} is not implemented")
+
+        if sum_stat == "identity":
+            self.sum_stat = identity
+        elif sum_stat == "sort":
+            self.sum_stat = np.sort
+        elif sum_stat == "mean":
+            self.sum_stat = np.mean
+        elif sum_stat == "median":
+            self.sum_stat = np.median
+        elif hasattr(sum_stat, "__call__"):
+            self.sum_stat = sum_stat
+        else:
+            raise ValueError(f"The summary statistics {sum_stat} is not implemented")
 
         super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs)
 
     def random(self, point=None, size=None):
@@ -51,16 +109,44 @@ def random(self, point=None, size=None):
         -------
         array
         """
-
-        raise NotImplementedError("Not implemented yet")
+        params = draw_values([*self.params], point=point, size=size)
+        if size is None:
+            return self.function(*params)
+        else:
+            return np.array([self.function(*params) for _ in range(size)])
 
     def _repr_latex_(self, name=None, dist=None):
         if dist is None:
             dist = self
-        name = r"\text{%s}" % name
-        function = dist.function
-        params = dist.parameters
-        sum_stat = dist.sum_stat
-        return r"${} \sim \text{{Simulator}}(\mathit{{function}}={},~\mathit{{parameters}}={},~\mathit{{summary statistics}}={})$".format(
-            name, function, params, sum_stat
-        )
+        name = name
+        function = dist.function.__name__
+        params = ", ".join([var.name for var in dist.params])
+        sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
+        distance = self.distance.__name__
+        return f"$\\text{{{name}}} \sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
+
+
+def identity(x):
+    """Identity function, used as a summary statistic."""
+    return x
+
+
+def gaussian_kernel(epsilon, obs_data, sim_data):
+    """Gaussian kernel distance function."""
+    return np.sum(-0.5 * ((obs_data - sim_data) / epsilon) ** 2)
+
+
+def wasserstein(epsilon, obs_data, sim_data):
+    """Wasserstein distance function.
+
+    We are assuming obs_data and sim_data are already sorted!
+    """
+    return np.mean(np.abs((obs_data - sim_data) / epsilon))
+
+
+def energy(epsilon, obs_data, sim_data):
+    """Energy distance function.
+
+    We are assuming obs_data and sim_data are already sorted!
+    """
+    return 1.4142 * np.mean(((obs_data - sim_data) / epsilon) ** 2) ** 0.5

pymc3/smc/sample_smc.py

Lines changed: 29 additions & 24 deletions
@@ -37,9 +37,7 @@ def sample_smc(
     tune_steps=True,
     p_acc_rate=0.99,
     threshold=0.5,
-    epsilon=1.0,
-    dist_func="gaussian_kernel",
-    sum_stat="identity",
+    save_sim_data=False,
     model=None,
     random_seed=-1,
     parallel=False,
@@ -74,13 +72,9 @@ def sample_smc(
         Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
         the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
         It should be between 0 and 1.
-    epsilon: float
-        Standard deviation of the gaussian pseudo likelihood. Only works with `kernel = ABC`
-    dist_func: str
-        Distance function. The only available option is ``gaussian_kernel``
-    sum_stat: str or callable
-        Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``, ``median``.
-        If a callable is based it should return a number or a 1d numpy array.
+    save_sim_data : bool
+        Whether or not to save the simulated data. This parameter only works with the ABC kernel.
+        The stored data corresponds to the posterior predictive distribution.
     model: Model (optional if in ``with`` context)).
     random_seed: int
         random seed
@@ -148,8 +142,15 @@ def sample_smc(
 
     if chains is None:
         chains = max(2, cores)
+    elif chains == 1:
+        cores = 1
 
-    _log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
+    _log.info(
+        (
+            f"Multiprocess sampling ({chains} chain{'s' if chains > 1 else ''} "
+            f"in {cores} job{'s' if cores > 1 else ''})"
+        )
+    )
 
     if random_seed == -1:
         random_seed = None
@@ -175,14 +176,12 @@
         tune_steps,
         p_acc_rate,
         threshold,
-        epsilon,
-        dist_func,
-        sum_stat,
+        save_sim_data,
         model,
     )
 
     t1 = time.time()
-    if parallel:
+    if parallel and chains > 1:
         loggers = [_log] + [None] * (chains - 1)
         pool = mp.Pool(cores)
         results = pool.starmap(
@@ -196,7 +195,7 @@
         for i in range(chains):
             results.append((sample_smc_int(*params, random_seed[i], i, _log)))
 
-    traces, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
+    traces, sim_data, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
     trace = MultiTrace(traces)
     trace.report._n_draws = draws
     trace.report._n_tune = 0
@@ -206,7 +205,10 @@
     trace.report.accept_ratios = accept_ratios
     trace.report.nsteps = nsteps
 
-    return trace
+    if save_sim_data:
+        return trace, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
+    else:
+        return trace
 
 
 def sample_smc_int(
@@ -217,9 +219,7 @@ def sample_smc_int(
     tune_steps,
     p_acc_rate,
     threshold,
-    epsilon,
-    dist_func,
-    sum_stat,
+    save_sim_data,
     model,
     random_seed,
     chain,
@@ -234,9 +234,7 @@
         tune_steps=tune_steps,
         p_acc_rate=p_acc_rate,
         threshold=threshold,
-        epsilon=epsilon,
-        dist_func=dist_func,
-        sum_stat=sum_stat,
+        save_sim_data=save_sim_data,
         model=model,
         random_seed=random_seed,
         chain=chain,
@@ -262,4 +260,11 @@
         accept_ratios.append(smc.acc_rate)
         nsteps.append(smc.n_steps)
 
-    return smc.posterior_to_trace(), smc.log_marginal_likelihood, betas, accept_ratios, nsteps
+    return (
+        smc.posterior_to_trace(),
+        smc.sim_data,
+        smc.log_marginal_likelihood,
+        betas,
+        accept_ratios,
+        nsteps,
+    )
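
Continuing the illustrative model from the earlier sketch, this is roughly how the new `save_sim_data` return value is consumed: the diff above shows `sample_smc` returning the trace plus a dict keyed by the observed RV's name when the flag is set. Anything beyond that (array shapes, downstream use) is an assumption.

```python
with model:  # the illustrative model from the earlier sketch
    trace, sim_data = pm.sample_smc(kernel="ABC", save_sim_data=True)

# sim_data maps the observed RV name ("s" in the sketch) to the simulated
# datasets collected during sampling, i.e. the posterior predictive
# distribution per the save_sim_data docstring in this diff.
print(sim_data["s"].shape)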
