Skip to content

Commit 6324408

Browse files
replace parallelize kwarg by reliance on cores setting
closes pymc-devs#3555
1 parent badd446 commit 6324408

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

pymc3/sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
452452
if has_population_samplers:
453453
_log.info('Population sampling ({} chains)'.format(chains))
454454
_print_step_hierarchy(step)
455-
trace = _sample_population(**sample_args)
455+
trace = _sample_population(**sample_args, parallelize=cores > 1)
456456
else:
457457
_log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
458458
_print_step_hierarchy(step)
@@ -689,7 +689,7 @@ def __init__(self, steppers, parallelize):
689689
if parallelize:
690690
try:
691691
# configure a child process for each stepper
692-
_log.info('Attempting to parallelize chains.')
692+
_log.info('Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.')
693693
import multiprocessing
694694
for c, stepper in enumerate(tqdm(steppers)):
695695
slave_end, master_end = multiprocessing.Pipe()
@@ -714,7 +714,7 @@ def __init__(self, steppers, parallelize):
714714
_log.debug('Error was: ', exec_info=True)
715715
else:
716716
_log.info('Chains are not parallelized. You can enable this by passing '
717-
'pm.sample(parallelize=True).')
717+
'pm.sample(cores=2).')
718718
return super().__init__()
719719

720720
def __enter__(self):

pymc3/tests/test_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ def test_nonparallelized_chains_are_random(self):
920920
x = Normal("x", 0, 1)
921921
for stepper in TestPopulationSamplers.steppers:
922922
step = stepper()
923-
trace = sample(chains=4, draws=20, tune=0, step=DEMetropolis())
923+
trace = sample(chains=4, cores=1, draws=20, tune=0, step=DEMetropolis())
924924
samples = np.array(trace.get_values("x", combine=False))[:, 5]
925925

926926
assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(
@@ -933,7 +933,7 @@ def test_parallelized_chains_are_random(self):
933933
x = Normal("x", 0, 1)
934934
for stepper in TestPopulationSamplers.steppers:
935935
step = stepper()
936-
trace = sample(chains=4, draws=20, tune=0, step=DEMetropolis(), parallelize=True)
936+
trace = sample(chains=4, cores=4, draws=20, tune=0, step=DEMetropolis())
937937
samples = np.array(trace.get_values("x", combine=False))[:, 5]
938938

939939
assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(

0 commit comments

Comments
 (0)