Skip to content

Commit dc574b7

Browse files
authored
improve ABC sampler (#3940)
* Expand ABC features. * valueerror * update notebook * remove unused import update release notes * fix notebook style and change order params argument
1 parent c6bba80 commit dc574b7

File tree

6 files changed

+246
-190
lines changed

6 files changed

+246
-190
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
- `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
2020
- Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
2121
- Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
22+
- SMC-ABC: add option to define summary statistics, allow sampling from more complex models, and remove redundant distances [#3940](https://github.com/pymc-devs/pymc3/issues/3940)
2223

2324
### Maintenance
2425
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).

docs/source/notebooks/SMC-ABC_Lotka-Volterra_example.ipynb

Lines changed: 183 additions & 139 deletions
Large diffs are not rendered by default.

pymc3/distributions/simulator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@
1919

2020

2121
class Simulator(NoDistribution):
22-
def __init__(self, function, *args, **kwargs):
22+
def __init__(self, function, *args, params=None, **kwargs):
2323
"""
2424
This class stores a function defined by the user in python language.
2525
2626
function: function
2727
Simulation function defined by the user.
28+
params: list
29+
Parameters passed to function.
2830
*args and **kwargs:
2931
Arguments and keywords arguments that the function takes.
3032
"""
3133

3234
self.function = function
35+
self.params = params
3336
observed = self.data
3437
super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs)
3538

pymc3/smc/sample_smc.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def sample_smc(
2828
p_acc_rate=0.99,
2929
threshold=0.5,
3030
epsilon=1.0,
31-
dist_func="absolute_error",
32-
sum_stat=False,
31+
dist_func="gaussian_kernel",
32+
sum_stat="identity",
3333
progressbar=False,
3434
model=None,
3535
random_seed=-1,
@@ -71,11 +71,10 @@ def sample_smc(
7171
epsilon: float
7272
Standard deviation of the gaussian pseudo likelihood. Only works with `kernel = ABC`
7373
dist_func: str
74-
Distance function. Available options are ``absolute_error`` (default) and
75-
``sum_of_squared_distance``. Only works with ``kernel = ABC``
76-
sum_stat: bool
77-
Whether to use or not a summary statistics. Defaults to False. Only works with
78-
``kernel = ABC``
74+
Distance function. The only available option is ``gaussian_kernel``
75+
sum_stat: str or callable
76+
Summary statistics. Available options are ``identity``, ``sorted``, ``mean``, ``median``.
77+
If a callable is passed it should return a number or a 1d numpy array.
7978
progressbar: bool
8079
Flag for displaying a progress bar. Defaults to False.
8180
model: Model (optional if in ``with`` context).

pymc3/smc/smc.py

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from ..step_methods.metropolis import MultivariateNormalProposal
3232
from ..backends.ndarray import NDArray
3333
from ..backends.base import MultiTrace
34-
from ..util import is_transformed_name
3534

3635
EXPERIMENTAL_WARNING = (
3736
"Warning: SMC-ABC methods are experimental step methods and not yet"
@@ -53,7 +52,7 @@ def __init__(
5352
threshold=0.5,
5453
epsilon=1.0,
5554
dist_func="absolute_error",
56-
sum_stat=False,
55+
sum_stat="Identity",
5756
progressbar=False,
5857
model=None,
5958
random_seed=-1,
@@ -140,6 +139,7 @@ def setup_kernel(self):
140139
self.epsilon,
141140
simulator.observations,
142141
simulator.distribution.function,
142+
[v.name for v in simulator.distribution.params],
143143
self.model,
144144
self.var_info,
145145
self.variables,
@@ -281,7 +281,7 @@ def mutate(self):
281281
self.priors[draw],
282282
self.likelihoods[draw],
283283
draw,
284-
*parameters
284+
*parameters,
285285
)
286286
for draw in iterator
287287
]
@@ -307,7 +307,7 @@ def posterior_to_trace(self):
307307
size = 0
308308
for var in varnames:
309309
shape, new_size = self.var_info[var]
310-
value.append(self.posterior[i][size: size + new_size].reshape(shape))
310+
value.append(self.posterior[i][size : size + new_size].reshape(shape))
311311
size += new_size
312312
strace.record({k: v for k, v in zip(varnames, value)})
313313
return MultiTrace([strace])
@@ -389,7 +389,16 @@ class PseudoLikelihood:
389389
"""
390390

391391
def __init__(
392-
self, epsilon, observations, function, model, var_info, variables, distance, sum_stat
392+
self,
393+
epsilon,
394+
observations,
395+
function,
396+
params,
397+
model,
398+
var_info,
399+
variables,
400+
distance,
401+
sum_stat,
393402
):
394403
"""
395404
epsilon: float
@@ -398,34 +407,48 @@ def __init__(
398407
observed data
399408
function: python function
400409
data simulator
410+
params: list
411+
names of the variables parameterizing the simulator.
401412
model: PyMC3 model
402413
var_info: dict
403414
generated by ``SMC.initialize_population``
404-
distance: str
405-
Distance function. Available options are ``absolute_error`` (default) and
406-
``sum_of_squared_distance``.
407-
sum_stat: bool
408-
Whether to use or not a summary statistics.
415+
distance : str or callable
416+
Distance function. The only available option is ``gaussian_kernel``
417+
sum_stat: str or callable
418+
Summary statistics. Available options are ``identity``, ``sorted``, ``mean``,
419+
``median``. The user can pass any valid Python function
409420
"""
410421
self.epsilon = epsilon
411-
self.observations = observations
412422
self.function = function
423+
self.params = params
413424
self.model = model
414425
self.var_info = var_info
415426
self.variables = variables
416427
self.varnames = [v.name for v in self.variables]
417428
self.unobserved_RVs = [v.name for v in self.model.unobserved_RVs]
418-
self.kernel = self.gauss_kernel
419-
self.dist_func = distance
420-
self.sum_stat = sum_stat
421429
self.get_unobserved_fn = self.model.fastfn(self.model.unobserved_RVs)
422430

423-
if distance == "absolute_error":
424-
self.dist_func = self.absolute_error
425-
elif distance == "sum_of_squared_distance":
426-
self.dist_func = self.sum_of_squared_distance
431+
if sum_stat == "identity":
432+
self.sum_stat = lambda x: x
433+
elif sum_stat == "sorted":
434+
self.sum_stat = np.sort
435+
elif sum_stat == "mean":
436+
self.sum_stat = np.mean
437+
elif sum_stat == "median":
438+
self.sum_stat = np.median
439+
elif hasattr(sum_stat, "__call__"):
440+
self.sum_stat = sum_stat
441+
else:
442+
raise ValueError(f"The summary statistics {sum_stat} is not implemented")
443+
444+
self.observations = self.sum_stat(observations)
445+
446+
if distance == "gaussian_kernel":
447+
self.distance = self.gaussian_kernel
448+
elif hasattr(distance, "__call__"):
449+
self.distance = distance
427450
else:
428-
raise ValueError("Distance metric not understood")
451+
raise ValueError(f"The distance metric {distance} is not implemented")
429452

430453
def posterior_to_function(self, posterior):
431454
model = self.model
@@ -436,32 +459,18 @@ def posterior_to_function(self, posterior):
436459
size = 0
437460
for var in self.variables:
438461
shape, new_size = var_info[var.name]
439-
varvalues.append(posterior[size: size + new_size].reshape(shape))
462+
varvalues.append(posterior[size : size + new_size].reshape(shape))
440463
size += new_size
441464
point = {k: v for k, v in zip(self.varnames, varvalues)}
442465
for varname, value in zip(self.unobserved_RVs, self.get_unobserved_fn(point)):
443-
if not is_transformed_name(varname):
466+
if varname in self.params:
444467
samples[varname] = value
445468
return samples
446469

447-
def gauss_kernel(self, value):
448-
epsilon = self.epsilon
449-
return (-(value ** 2) / epsilon ** 2 + np.log(1 / (2 * np.pi * epsilon ** 2))) / 2.0
450-
451-
def absolute_error(self, a, b):
452-
if self.sum_stat:
453-
return np.abs(a.mean() - b.mean())
454-
else:
455-
return np.mean(np.atleast_2d(np.abs(a - b)))
456-
457-
def sum_of_squared_distance(self, a, b):
458-
if self.sum_stat:
459-
return np.sum(np.atleast_2d((a.mean() - b.mean()) ** 2))
460-
else:
461-
return np.mean(np.sum(np.atleast_2d((a - b) ** 2)))
470+
def gaussian_kernel(self, obs_data, sim_data):
471+
return np.sum(-0.5 * ((obs_data - sim_data) / self.epsilon) ** 2)
462472

463473
def __call__(self, posterior):
464474
func_parameters = self.posterior_to_function(posterior)
465-
sim_data = self.function(**func_parameters)
466-
value = self.dist_func(self.observations, sim_data)
467-
return self.kernel(value)
475+
sim_data = self.sum_stat(self.function(**func_parameters))
476+
return self.distance(self.observations, sim_data)

pymc3/tests/test_smc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,19 @@ def test_start(self):
9898
class TestSMCABC(SeededTest):
9999
def setup_class(self):
100100
super().setup_class()
101-
self.data = np.sort(np.random.normal(loc=0, scale=1, size=1000))
101+
self.data = np.random.normal(loc=0, scale=1, size=1000)
102102

103103
def normal_sim(a, b):
104-
return np.sort(np.random.normal(a, b, 1000))
104+
return np.random.normal(a, b, 1000)
105105

106106
with pm.Model() as self.SMABC_test:
107107
a = pm.Normal("a", mu=0, sd=5)
108108
b = pm.HalfNormal("b", sd=2)
109-
s = pm.Simulator("s", normal_sim, observed=self.data)
109+
s = pm.Simulator("s", normal_sim, params=(a, b), observed=self.data)
110110

111111
def test_one_gaussian(self):
112112
with self.SMABC_test:
113-
trace = pm.sample_smc(draws=2000, kernel="ABC", epsilon=0.1)
113+
trace = pm.sample_smc(draws=1000, kernel="ABC", sum_stat="sorted", epsilon=1)
114114

115115
np.testing.assert_almost_equal(self.data.mean(), trace["a"].mean(), decimal=2)
116116
np.testing.assert_almost_equal(self.data.std(), trace["b"].mean(), decimal=1)

0 commit comments

Comments
 (0)