19
19
20
20
from collections .abc import Iterable
21
21
22
+ import cloudpickle
22
23
import numpy as np
23
24
24
25
from arviz import InferenceData
@@ -224,9 +225,12 @@ def sample_smc(
224
225
pbars = [pbar ] + [None ] * (chains - 1 )
225
226
226
227
pool = mp .Pool (cores )
228
+ # "manually" (de)serialize params before/after multiprocessing
229
+ params = tuple (cloudpickle .dumps (p ) for p in params )
227
230
results = pool .starmap (
228
- sample_smc_int , [(* params , random_seed [i ], i , pbars [i ]) for i in range (chains )]
231
+ _sample_smc_int , [(* params , random_seed [i ], i , pbars [i ]) for i in range (chains )]
229
232
)
233
+ results = tuple (cloudpickle .loads (r ) for r in results )
230
234
pool .close ()
231
235
pool .join ()
232
236
@@ -237,7 +241,7 @@ def sample_smc(
237
241
for i in range (chains ):
238
242
pbar .offset = 100 * i
239
243
pbar .base_comment = f"Chain: { i + 1 } /{ chains } "
240
- results .append (sample_smc_int (* params , random_seed [i ], i , pbar ))
244
+ results .append (_sample_smc_int (* params , random_seed [i ], i , pbar ))
241
245
242
246
(
243
247
traces ,
@@ -316,7 +320,7 @@ def sample_smc(
316
320
return posterior
317
321
318
322
319
- def sample_smc_int (
323
+ def _sample_smc_int (
320
324
draws ,
321
325
kernel ,
322
326
n_steps ,
@@ -332,6 +336,36 @@ def sample_smc_int(
332
336
progressbar = None ,
333
337
):
334
338
"""Run one SMC instance."""
339
+ in_out_pickled = type (model ) == bytes
340
+ if in_out_pickled :
341
+ # function was called in multiprocessing context, deserialize first
342
+ (
343
+ draws ,
344
+ kernel ,
345
+ n_steps ,
346
+ start ,
347
+ tune_steps ,
348
+ p_acc_rate ,
349
+ threshold ,
350
+ save_sim_data ,
351
+ save_log_pseudolikelihood ,
352
+ model ,
353
+ ) = map (
354
+ cloudpickle .loads ,
355
+ (
356
+ draws ,
357
+ kernel ,
358
+ n_steps ,
359
+ start ,
360
+ tune_steps ,
361
+ p_acc_rate ,
362
+ threshold ,
363
+ save_sim_data ,
364
+ save_log_pseudolikelihood ,
365
+ model ,
366
+ ),
367
+ )
368
+
335
369
smc = SMC (
336
370
draws = draws ,
337
371
kernel = kernel ,
@@ -375,7 +409,7 @@ def sample_smc_int(
375
409
accept_ratios .append (smc .acc_rate )
376
410
nsteps .append (smc .n_steps )
377
411
378
- return (
412
+ results = (
379
413
smc .posterior_to_trace (),
380
414
smc .sim_data ,
381
415
smc .log_marginal_likelihood ,
@@ -384,3 +418,8 @@ def sample_smc_int(
384
418
accept_ratios ,
385
419
nsteps ,
386
420
)
421
+
422
+ if in_out_pickled :
423
+ results = cloudpickle .dumps (results )
424
+
425
+ return results
0 commit comments