13
13
# limitations under the License.
14
14
import multiprocessing
15
15
import os
16
+ import platform
16
17
17
18
import aesara
18
19
import aesara .tensor as at
20
+ import cloudpickle
19
21
import numpy as np
20
22
import pytest
21
23
@@ -86,7 +88,8 @@ def test_remote_pipe_closed():
86
88
87
89
88
90
@pytest .mark .xfail (reason = "Unclear" )
89
- def test_abort ():
91
+ @pytest .mark .parametrize ("mp_start_method" , ["spawn" , "fork" ])
92
+ def test_abort (mp_start_method ):
90
93
with pm .Model () as model :
91
94
a = pm .Normal ("a" , shape = 1 )
92
95
b = pm .HalfNormal ("b" )
@@ -95,8 +98,16 @@ def test_abort():
95
98
96
99
step = pm .CompoundStep ([step1 , step2 ])
97
100
101
+ # on Windows we cannot fork
102
+ if platform .system () == "Windows" and mp_start_method == "fork" :
103
+ return
104
+ if mp_start_method == "spawn" :
105
+ step_method_pickled = cloudpickle .dumps (step , protocol = - 1 )
106
+ else :
107
+ step_method_pickled = None
108
+
98
109
for abort in [False , True ]:
99
- ctx = multiprocessing .get_context ()
110
+ ctx = multiprocessing .get_context (mp_start_method )
100
111
proc = ps .ProcessAdapter (
101
112
10 ,
102
113
10 ,
@@ -105,7 +116,7 @@ def test_abort():
105
116
seed = 1 ,
106
117
mp_ctx = ctx ,
107
118
start = {"a" : floatX (np .array ([1.0 ])), "b_log__" : floatX (np .array (2.0 ))},
108
- step_method_pickled = None ,
119
+ step_method_pickled = step_method_pickled ,
109
120
)
110
121
proc .start ()
111
122
while True :
@@ -118,7 +129,8 @@ def test_abort():
118
129
proc .join ()
119
130
120
131
121
- def test_explicit_sample ():
132
+ @pytest .mark .parametrize ("mp_start_method" , ["spawn" , "fork" ])
133
+ def test_explicit_sample (mp_start_method ):
122
134
with pm .Model () as model :
123
135
a = pm .Normal ("a" , shape = 1 )
124
136
b = pm .HalfNormal ("b" )
@@ -127,7 +139,15 @@ def test_explicit_sample():
127
139
128
140
step = pm .CompoundStep ([step1 , step2 ])
129
141
130
- ctx = multiprocessing .get_context ()
142
+ # on Windows we cannot fork
143
+ if platform .system () == "Windows" and mp_start_method == "fork" :
144
+ return
145
+ if mp_start_method == "spawn" :
146
+ step_method_pickled = cloudpickle .dumps (step , protocol = - 1 )
147
+ else :
148
+ step_method_pickled = None
149
+
150
+ ctx = multiprocessing .get_context (mp_start_method )
131
151
proc = ps .ProcessAdapter (
132
152
10 ,
133
153
10 ,
@@ -136,7 +156,7 @@ def test_explicit_sample():
136
156
seed = 1 ,
137
157
mp_ctx = ctx ,
138
158
start = {"a" : floatX (np .array ([1.0 ])), "b_log__" : floatX (np .array (2.0 ))},
139
- step_method_pickled = None ,
159
+ step_method_pickled = step_method_pickled ,
140
160
)
141
161
proc .start ()
142
162
while True :
0 commit comments