Skip to content

Commit 17d9f2e

Browse files
committed
Deprecate ABC specific code in SMC
1 parent 6482cef commit 17d9f2e

File tree

3 files changed

+66
-226
lines changed

3 files changed

+66
-226
lines changed

pymc3/smc/sample_smc.py

Lines changed: 29 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@
3636

3737
def sample_smc(
3838
draws=2000,
39-
kernel="metropolis",
39+
kernel=None,
4040
n_steps=25,
4141
*,
4242
start=None,
4343
tune_steps=True,
4444
p_acc_rate=0.85,
4545
threshold=0.5,
46-
save_sim_data=False,
47-
save_log_pseudolikelihood=True,
46+
save_sim_data=None,
47+
save_log_pseudolikelihood=None,
4848
model=None,
4949
random_seed=-1,
5050
parallel=None,
@@ -63,9 +63,6 @@ def sample_smc(
6363
draws: int
6464
The number of samples to draw from the posterior (i.e. last stage). And also the number of
6565
independent chains. Defaults to 2000.
66-
kernel: str
67-
Kernel method for the SMC sampler. Available option are ``metropolis`` (default) and `ABC`.
68-
Use `ABC` for likelihood free inference together with a ``pm.Simulator``.
6966
n_steps: int
7067
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
7168
for the first stage and for the others it will be determined automatically based on the
@@ -83,13 +80,6 @@ def sample_smc(
8380
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
8481
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
8582
It should be between 0 and 1.
86-
save_sim_data : bool
87-
Whether or not to save the simulated data. This parameter only works with the ABC kernel.
88-
The stored data corresponds to a samples from the posterior predictive distribution.
89-
save_log_pseudolikelihood : bool
90-
Whether or not to save the log pseudolikelihood values. This parameter only works with the
91-
ABC kernel. The stored data can be used to compute LOO or WAIC values. Computing LOO/WAIC
92-
values from log pseudolikelihood values is experimental.
9383
model: Model (optional if in ``with`` context)).
9484
random_seed: int
9585
random seed
@@ -157,6 +147,30 @@ def sample_smc(
157147
%282007%29133:7%28816%29>`__
158148
"""
159149

150+
if isinstance(kernel, str) and kernel.lower() == "abc":
151+
warnings.warn(
152+
f'The kernel "{kernel}" in sample_smc has been deprecated. '
153+
f"It is no longer needed to specify it.",
154+
DeprecationWarning,
155+
stacklevel=2,
156+
)
157+
158+
if save_sim_data is not None:
159+
warnings.warn(
160+
"save_sim_data has been deprecated. Use pm.sample_posterior_predictive "
161+
"to obtain the same type of samples.",
162+
DeprecationWarning,
163+
stacklevel=2,
164+
)
165+
166+
if save_log_pseudolikelihood is not None:
167+
warnings.warn(
168+
"save_log_pseudolikelihood has been deprecated. This information is "
169+
"now saved as log_likelihood in models with Simulator distributions.",
170+
DeprecationWarning,
171+
stacklevel=2,
172+
)
173+
160174
if parallel is not None:
161175
warnings.warn(
162176
"The argument parallel is deprecated, use the argument cores instead.",
@@ -199,22 +213,13 @@ def sample_smc(
199213
if not isinstance(random_seed, Iterable):
200214
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
201215

202-
if kernel.lower() == "abc":
203-
if len(model.observed_RVs) != 1:
204-
warnings.warn("SMC-ABC only works properly with models with one observed variable")
205-
if model.potentials:
206-
_log.info("Potentials will be added to the prior term")
207-
208216
params = (
209217
draws,
210-
kernel,
211218
n_steps,
212219
start,
213220
tune_steps,
214221
p_acc_rate,
215222
threshold,
216-
save_sim_data,
217-
save_log_pseudolikelihood,
218223
model,
219224
)
220225

@@ -245,9 +250,7 @@ def sample_smc(
245250

246251
(
247252
traces,
248-
sim_data,
249253
log_marginal_likelihoods,
250-
log_pseudolikelihood,
251254
betas,
252255
accept_ratios,
253256
nsteps,
@@ -263,7 +266,6 @@ def sample_smc(
263266
trace.report._n_draws = draws
264267
trace.report._n_tune = _n_tune
265268
trace.report.log_marginal_likelihood = log_marginal_likelihoods
266-
trace.report.log_pseudolikelihood = log_pseudolikelihood
267269
trace.report.betas = betas
268270
trace.report.accept_ratios = accept_ratios
269271
trace.report.nsteps = nsteps
@@ -313,23 +315,16 @@ def sample_smc(
313315
trace.report._run_convergence_checks(idata, model)
314316
trace.report._log_summary()
315317

316-
posterior = idata if return_inferencedata else trace
317-
if save_sim_data:
318-
return posterior, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
319-
else:
320-
return posterior
318+
return idata if return_inferencedata else trace
321319

322320

323321
def _sample_smc_int(
324322
draws,
325-
kernel,
326323
n_steps,
327324
start,
328325
tune_steps,
329326
p_acc_rate,
330327
threshold,
331-
save_sim_data,
332-
save_log_pseudolikelihood,
333328
model,
334329
random_seed,
335330
chain,
@@ -339,43 +334,26 @@ def _sample_smc_int(
339334
in_out_pickled = type(model) == bytes
340335
if in_out_pickled:
341336
# function was called in multiprocessing context, deserialize first
342-
(
343-
draws,
344-
kernel,
345-
n_steps,
346-
start,
347-
tune_steps,
348-
p_acc_rate,
349-
threshold,
350-
save_sim_data,
351-
save_log_pseudolikelihood,
352-
model,
353-
) = map(
337+
(draws, n_steps, start, tune_steps, p_acc_rate, threshold, model,) = map(
354338
cloudpickle.loads,
355339
(
356340
draws,
357-
kernel,
358341
n_steps,
359342
start,
360343
tune_steps,
361344
p_acc_rate,
362345
threshold,
363-
save_sim_data,
364-
save_log_pseudolikelihood,
365346
model,
366347
),
367348
)
368349

369350
smc = SMC(
370351
draws=draws,
371-
kernel=kernel,
372352
n_steps=n_steps,
373353
start=start,
374354
tune_steps=tune_steps,
375355
p_acc_rate=p_acc_rate,
376356
threshold=threshold,
377-
save_sim_data=save_sim_data,
378-
save_log_pseudolikelihood=save_log_pseudolikelihood,
379357
model=model,
380358
random_seed=random_seed,
381359
chain=chain,
@@ -411,9 +389,7 @@ def _sample_smc_int(
411389

412390
results = (
413391
smc.posterior_to_trace(),
414-
smc.sim_data,
415392
smc.log_marginal_likelihood,
416-
smc.log_pseudolikelihood,
417393
betas,
418394
accept_ratios,
419395
nsteps,

0 commit comments

Comments
 (0)