Skip to content

Commit 03b10de

Browse files
committed
Fix tests
1 parent dd21cc4 commit 03b10de

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

pymc3/parallel_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class _Process(multiprocessing.Process):
3030
"""
3131
def __init__(self, name, msg_pipe, step_method, shared_point,
3232
draws, tune, seed):
33-
super(_Process, self).__init__(daemon=True)
33+
super(_Process, self).__init__(daemon=True, name=name)
3434
self._msg_pipe = msg_pipe
3535
self._step_method = step_method
3636
self._shared_point = shared_point

pymc3/sampling.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -965,19 +965,9 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
965965
raise ValueError('Argument `trace` is invalid.')
966966

967967

968-
def _mp_sample(**kwargs):
969-
cores = kwargs.pop('cores')
970-
chain = kwargs.pop('chain')
971-
rseed = kwargs.pop('random_seed')
972-
start = kwargs.pop('start')
973-
chains = kwargs.pop('chains')
974-
draws = kwargs.pop('draws')
975-
tune = kwargs.pop('tune')
976-
step = kwargs.pop('step')
977-
progressbar = kwargs.pop('progressbar')
978-
use_mmap = kwargs.pop('use_mmap')
979-
model = kwargs.pop('model', None)
980-
trace = kwargs.pop('trace', None)
968+
def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
969+
start, progressbar, trace=None, model=None, use_mmap=False,
970+
**kwargs):
981971

982972
if sys.version_info.major >= 3:
983973
import pymc3.parallel_sampling as ps
@@ -1000,7 +990,8 @@ def _mp_sample(**kwargs):
1000990
traces.append(strace)
1001991

1002992
sampler = ps.ParallelSampler(
1003-
draws, tune, chains, cores, rseed, start, step, chain, progressbar)
993+
draws, tune, chains, cores, random_seed, start, step,
994+
chain, progressbar)
1004995
try:
1005996
with sampler:
1006997
for draw in sampler:
@@ -1022,8 +1013,12 @@ def _mp_sample(**kwargs):
10221013
else:
10231014
chain_nums = list(range(chain, chain + chains))
10241015
pbars = [progressbar] + [False] * (chains - 1)
1025-
jobs = (delayed(_sample)(*args, **kwargs)
1026-
for args in zip(chain_nums, pbars, rseed, start))
1016+
jobs = (delayed(_sample)(
1017+
chain=args[0], progressbar=args[1], random_seed=args[2],
1018+
start=args[3], draws=draws, step=step, trace=trace,
1019+
tune=tune, model=model, **kwargs
1020+
)
1021+
for args in zip(chain_nums, pbars, random_seed, start))
10271022
if use_mmap:
10281023
traces = Parallel(n_jobs=cores)(jobs)
10291024
else:

pymc3/tests/test_parallel_sampling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import time
2+
import sys
3+
import pytest
24

35
import pymc3.parallel_sampling as ps
46
import pymc3 as pm
57

68

9+
@pytest.mark.skipif(sys.version_info < (3,3),
10+
reason="requires python3.3")
711
def test_abort():
812
with pm.Model() as model:
913
a = pm.Normal('a', shape=1)
@@ -21,6 +25,8 @@ def test_abort():
2125
proc.join()
2226

2327

28+
@pytest.mark.skipif(sys.version_info < (3,3),
29+
reason="requires python3.3")
2430
def test_explicit_sample():
2531
with pm.Model() as model:
2632
a = pm.Normal('a', shape=1)
@@ -46,6 +52,8 @@ def test_explicit_sample():
4652
print(time.time() - start)
4753

4854

55+
@pytest.mark.skipif(sys.version_info < (3,3),
56+
reason="requires python3.3")
4957
def test_iterator():
5058
with pm.Model() as model:
5159
a = pm.Normal('a', shape=1)

0 commit comments

Comments
 (0)