Skip to content

Commit 211165e

Browse files
SpaakricardoV94
authored andcommitted
pickle step method for ProcessAdapter when mp_method is spawn
1 parent 8e8a477 commit 211165e

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

pymc3/parallel_sampling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ def __init__(
257257
if step_method_pickled is not None:
258258
step_method_send = step_method_pickled
259259
else:
260+
if mp_ctx.get_start_method() == "spawn":
261+
raise ValueError(
262+
"please provide a pre-pickled step method when multiprocessing start method is 'spawn'"
263+
)
260264
step_method_send = step_method
261265

262266
self._process = mp_ctx.Process(

pymc3/tests/test_parallel_sampling.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414
import multiprocessing
1515
import os
16+
import platform
1617

1718
import aesara
1819
import aesara.tensor as at
20+
import cloudpickle
1921
import numpy as np
2022
import pytest
2123

@@ -86,7 +88,8 @@ def test_remote_pipe_closed():
8688

8789

8890
@pytest.mark.xfail(reason="Unclear")
89-
def test_abort():
91+
@pytest.mark.parametrize("mp_start_method", ["spawn", "fork"])
92+
def test_abort(mp_start_method):
9093
with pm.Model() as model:
9194
a = pm.Normal("a", shape=1)
9295
b = pm.HalfNormal("b")
@@ -95,8 +98,16 @@ def test_abort():
9598

9699
step = pm.CompoundStep([step1, step2])
97100

101+
# on Windows we cannot fork
102+
if platform.system() == "Windows" and mp_start_method == "fork":
103+
return
104+
if mp_start_method == "spawn":
105+
step_method_pickled = cloudpickle.dumps(step, protocol=-1)
106+
else:
107+
step_method_pickled = None
108+
98109
for abort in [False, True]:
99-
ctx = multiprocessing.get_context()
110+
ctx = multiprocessing.get_context(mp_start_method)
100111
proc = ps.ProcessAdapter(
101112
10,
102113
10,
@@ -105,7 +116,7 @@ def test_abort():
105116
seed=1,
106117
mp_ctx=ctx,
107118
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
108-
step_method_pickled=None,
119+
step_method_pickled=step_method_pickled,
109120
)
110121
proc.start()
111122
while True:
@@ -118,7 +129,8 @@ def test_abort():
118129
proc.join()
119130

120131

121-
def test_explicit_sample():
132+
@pytest.mark.parametrize("mp_start_method", ["spawn", "fork"])
133+
def test_explicit_sample(mp_start_method):
122134
with pm.Model() as model:
123135
a = pm.Normal("a", shape=1)
124136
b = pm.HalfNormal("b")
@@ -127,7 +139,15 @@ def test_explicit_sample():
127139

128140
step = pm.CompoundStep([step1, step2])
129141

130-
ctx = multiprocessing.get_context()
142+
# on Windows we cannot fork
143+
if platform.system() == "Windows" and mp_start_method == "fork":
144+
return
145+
if mp_start_method == "spawn":
146+
step_method_pickled = cloudpickle.dumps(step, protocol=-1)
147+
else:
148+
step_method_pickled = None
149+
150+
ctx = multiprocessing.get_context(mp_start_method)
131151
proc = ps.ProcessAdapter(
132152
10,
133153
10,
@@ -136,7 +156,7 @@ def test_explicit_sample():
136156
seed=1,
137157
mp_ctx=ctx,
138158
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
139-
step_method_pickled=None,
159+
step_method_pickled=step_method_pickled,
140160
)
141161
proc.start()
142162
while True:

0 commit comments

Comments
 (0)