@@ -222,52 +222,54 @@ def sample_smc(
     )
 
     t1 = time.time()
+
     if cores > 1:
-        pbar = progress_bar((), total=100, display=progressbar)
-        pbar.update(0)
-        pbars = [pbar] + [None] * (chains - 1)
-
-        pool = mp.Pool(cores)
-
-        # "manually" (de)serialize params before/after multiprocessing
-        params = tuple(cloudpickle.dumps(p) for p in params)
-        kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
-        results = _starmap_with_kwargs(
-            pool,
-            _sample_smc_int,
-            [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
-            repeat(kernel_kwargs),
+        results = run_chains_parallel(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
         )
-        results = tuple(cloudpickle.loads(r) for r in results)
-        pool.close()
-        pool.join()
-
     else:
-        results = []
-        pbar = progress_bar((), total=100 * chains, display=progressbar)
-        pbar.update(0)
-        for chain in range(chains):
-            pbar.offset = 100 * chain
-            pbar.base_comment = f"Chain: {chain + 1}/{chains}"
-            results.append(
-                _sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs)
-            )
-
+        results = run_chains_sequential(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
+        )
     (
         traces,
         sample_stats,
         sample_settings,
     ) = zip(*results)
 
     trace = MultiTrace(traces)
-    idata = None
 
-    # Save sample_stats
     _t_sampling = time.time() - t1
+    sample_stats, idata = _save_sample_stats(
+        sample_settings,
+        sample_stats,
+        chains,
+        trace,
+        return_inferencedata,
+        _t_sampling,
+        idata_kwargs,
+        model,
+    )
+
+    if compute_convergence_checks:
+        _compute_convergence_checks(idata, draws, model, trace)
+    return idata if return_inferencedata else trace
+
+
+def _save_sample_stats(
+    sample_settings,
+    sample_stats,
+    chains,
+    trace,
+    return_inferencedata,
+    _t_sampling,
+    idata_kwargs,
+    model,
+):
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
-
     sample_stats_dict = sample_stats[0]
+
     if chains > 1:
         # Collect the stat values from each chain in a single list
         for stat in sample_stats[0].keys():
@@ -281,6 +283,7 @@ def sample_smc(
             setattr(trace.report, stat, value)
         for stat, value in sample_settings_dict.items():
             setattr(trace.report, stat, value)
+        idata = None
     else:
         for stat, value in sample_stats_dict.items():
             if chains > 1:
@@ -303,19 +306,20 @@ def sample_smc(
         idata = to_inference_data(trace, **ikwargs)
         idata = InferenceData(**idata, sample_stats=sample_stats)
 
-    if compute_convergence_checks:
-        if draws < 100:
-            warnings.warn(
-                "The number of samples is too small to check convergence reliably.",
-                stacklevel=2,
-            )
-        else:
-            if idata is None:
-                idata = to_inference_data(trace, log_likelihood=False)
-            trace.report._run_convergence_checks(idata, model)
-            trace.report._log_summary()
+    return sample_stats, idata
 
-    return idata if return_inferencedata else trace
+
+def _compute_convergence_checks(idata, draws, model, trace):
+    if draws < 100:
+        warnings.warn(
+            "The number of samples is too small to check convergence reliably.",
+            stacklevel=2,
+        )
+    else:
+        if idata is None:
+            idata = to_inference_data(trace, log_likelihood=False)
+        trace.report._run_convergence_checks(idata, model)
+        trace.report._log_summary()
 
 
 def _sample_smc_int(
@@ -391,6 +395,39 @@ def _sample_smc_int(
     return results
 
 
+def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores):
+    pbar = progress_bar((), total=100, display=progressbar)
+    pbar.update(0)
+    pbars = [pbar] + [None] * (chains - 1)
+
+    pool = mp.Pool(cores)
+
+    # "manually" (de)serialize params before/after multiprocessing
+    params = tuple(cloudpickle.dumps(p) for p in params)
+    kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
+    results = _starmap_with_kwargs(
+        pool,
+        to_run,
+        [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
+        repeat(kernel_kwargs),
+    )
+    results = tuple(cloudpickle.loads(r) for r in results)
+    pool.close()
+    pool.join()
+    return results
+
+
+def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs):
+    results = []
+    pbar = progress_bar((), total=100 * chains, display=progressbar)
+    pbar.update(0)
+    for chain in range(chains):
+        pbar.offset = 100 * chain
+        pbar.base_comment = f"Chain: {chain + 1}/{chains}"
+        results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs))
+    return results
+
+
 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
     # Helper function to allow kwargs with Pool.starmap
     # Copied from https://stackoverflow.com/a/53173433/13311693
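
The hunk ends before the body of _starmap_with_kwargs. Judging by its comment and the linked Stack Overflow answer, the helper works by zipping the target function together with each positional-args tuple and kwargs dict, then unpacking that bundle in a top-level trampoline, since Pool.starmap can only unpack positional arguments. A minimal sketch of that pattern (the trampoline name _apply_args_and_kwargs is an assumption, not confirmed by this excerpt):

    from itertools import repeat

    def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
        # Pool.starmap unpacks positional arguments only, so bundle the
        # function, its args tuple, and its kwargs dict into one triple.
        args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
        return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

    def _apply_args_and_kwargs(fn, args, kwargs):
        # Hypothetical top-level trampoline: rebuilds the intended call
        # inside the worker process.
        return fn(*args, **kwargs)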
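The "manually (de)serialize" comment in run_chains_parallel reflects that multiprocessing pickles arguments with the standard pickle module, which rejects objects like lambdas and some model internals; cloudpickle handles them, so everything crosses the process boundary as cloudpickle bytes and is decoded on the other side. A self-contained sketch of that round-trip (a standalone illustration, not code from this PR):

    import multiprocessing as mp
    import cloudpickle

    def _worker(payload):
        # Decode the function and argument inside the worker process...
        fn, arg = cloudpickle.loads(payload)
        # ...and encode the result for the trip back.
        return cloudpickle.dumps(fn(arg))

    if __name__ == "__main__":
        # Passed directly, a lambda would make plain pickle (and hence
        # mp.Pool) fail; as cloudpickle bytes it travels without issue.
        job = cloudpickle.dumps((lambda x: x * 2, 21))
        with mp.Pool(2) as pool:
            (out,) = pool.map(_worker, [job])
        print(cloudpickle.loads(out))  # 42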