Skip to content

Commit 8cc6f8f

Browse files
committed
Add test for remote process crash
1 parent f0b5d28 commit 8cc6f8f

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

pymc3/tests/test_parallel_sampling.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import multiprocessing
15+
import os
1516

1617
import pytest
1718
import pymc3.parallel_sampling as ps
1819
import pymc3 as pm
20+
import theano
21+
import theano.tensor as tt
22+
import numpy as np
1923

2024

2125
def test_context():
@@ -46,6 +50,31 @@ def test_bad_unpickle():
4650
assert 'could not be unpickled' in str(exc_info.getrepr(style='short'))
4751

4852

53+
@theano.as_op(
54+
[
55+
tt.dvector if theano.config.floatX == "float64" else tt.fvector,
56+
tt.iscalar,
57+
],
58+
[tt.dvector if theano.config.floatX == "float64" else tt.fvector],
59+
)
60+
def _crash_remote_process(a, master_pid):
61+
if os.getpid() != master_pid:
62+
os.exit(0)
63+
return 2 * np.array(a)
64+
65+
66+
def test_remote_pipe_closed():
67+
master_pid = os.getpid()
68+
with pm.Model():
69+
x = pm.Normal('x', shape=2, mu=0.1)
70+
tt_pid = tt.as_tensor_variable(np.array(master_pid, dtype='int32'))
71+
pm.Normal('y', mu=_crash_remote_process(x, tt_pid), shape=2)
72+
73+
step = pm.Metropolis()
74+
with pytest.raises(RuntimeError, match="Chain [0-9] failed"):
75+
pm.sample(step=step, mp_ctx='spawn', tune=2, draws=2, cores=2, chains=2)
76+
77+
4978
def test_abort():
5079
with pm.Model() as model:
5180
a = pm.Normal('a', shape=1)

0 commit comments

Comments
 (0)