Skip to content

Commit 78f9da0

Browse files
committed
use cloudpickle in smc sampling
1 parent 6352da0 commit 78f9da0

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

pymc3/smc/sample_smc.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from collections.abc import Iterable
2121

22+
import cloudpickle
2223
import numpy as np
2324

2425
from arviz import InferenceData
@@ -224,9 +225,12 @@ def sample_smc(
224225
pbars = [pbar] + [None] * (chains - 1)
225226

226227
pool = mp.Pool(cores)
228+
# "manually" (de)serialize params before/after multiprocessing
229+
params = tuple(cloudpickle.dumps(p) for p in params)
227230
results = pool.starmap(
228-
sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)]
231+
_sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)]
229232
)
233+
results = tuple(cloudpickle.loads(r) for r in results)
230234
pool.close()
231235
pool.join()
232236

@@ -237,7 +241,7 @@ def sample_smc(
237241
for i in range(chains):
238242
pbar.offset = 100 * i
239243
pbar.base_comment = f"Chain: {i+1}/{chains}"
240-
results.append(sample_smc_int(*params, random_seed[i], i, pbar))
244+
results.append(_sample_smc_int(*params, random_seed[i], i, pbar))
241245

242246
(
243247
traces,
@@ -316,7 +320,7 @@ def sample_smc(
316320
return posterior
317321

318322

319-
def sample_smc_int(
323+
def _sample_smc_int(
320324
draws,
321325
kernel,
322326
n_steps,
@@ -332,6 +336,36 @@ def sample_smc_int(
332336
progressbar=None,
333337
):
334338
"""Run one SMC instance."""
339+
in_out_pickled = type(model) == bytes
340+
if in_out_pickled:
341+
# function was called in multiprocessing context, deserialize first
342+
(
343+
draws,
344+
kernel,
345+
n_steps,
346+
start,
347+
tune_steps,
348+
p_acc_rate,
349+
threshold,
350+
save_sim_data,
351+
save_log_pseudolikelihood,
352+
model,
353+
) = map(
354+
cloudpickle.loads,
355+
(
356+
draws,
357+
kernel,
358+
n_steps,
359+
start,
360+
tune_steps,
361+
p_acc_rate,
362+
threshold,
363+
save_sim_data,
364+
save_log_pseudolikelihood,
365+
model,
366+
),
367+
)
368+
335369
smc = SMC(
336370
draws=draws,
337371
kernel=kernel,
@@ -375,7 +409,7 @@ def sample_smc_int(
375409
accept_ratios.append(smc.acc_rate)
376410
nsteps.append(smc.n_steps)
377411

378-
return (
412+
results = (
379413
smc.posterior_to_trace(),
380414
smc.sim_data,
381415
smc.log_marginal_likelihood,
@@ -384,3 +418,8 @@ def sample_smc_int(
384418
accept_ratios,
385419
nsteps,
386420
)
421+
422+
if in_out_pickled:
423+
results = cloudpickle.dumps(results)
424+
425+
return results

0 commit comments

Comments
 (0)