Skip to content

Commit 8e8a477

Browse files
aloctavodiaricardoV94
authored andcommitted
Fix failing tests in test_parallel_sampling
1 parent 3544ae1 commit 8e8a477

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

pymc3/tests/test_parallel_sampling.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import pymc3 as pm
2626
import pymc3.parallel_sampling as ps
2727

28+
from pymc3.aesaraf import floatX
29+
2830

2931
def test_context():
3032
with pm.Model():
@@ -83,15 +85,13 @@ def test_remote_pipe_closed():
8385
pm.sample(step=step, mp_ctx="spawn", tune=2, draws=2, cores=2, chains=2)
8486

8587

86-
@pytest.mark.xfail(
87-
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
88-
)
88+
@pytest.mark.xfail(reason="Unclear")
8989
def test_abort():
9090
with pm.Model() as model:
9191
a = pm.Normal("a", shape=1)
92-
pm.HalfNormal("b")
93-
step1 = pm.NUTS([a])
94-
step2 = pm.Metropolis([model["b_log__"]])
92+
b = pm.HalfNormal("b")
93+
step1 = pm.NUTS([model.rvs_to_values[a]])
94+
step2 = pm.Metropolis([model.rvs_to_values[b]])
9595

9696
step = pm.CompoundStep([step1, step2])
9797

@@ -104,7 +104,7 @@ def test_abort():
104104
chain=3,
105105
seed=1,
106106
mp_ctx=ctx,
107-
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
107+
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
108108
step_method_pickled=None,
109109
)
110110
proc.start()
@@ -118,15 +118,12 @@ def test_abort():
118118
proc.join()
119119

120120

121-
@pytest.mark.xfail(
122-
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
123-
)
124121
def test_explicit_sample():
125122
with pm.Model() as model:
126123
a = pm.Normal("a", shape=1)
127-
pm.HalfNormal("b")
128-
step1 = pm.NUTS([a])
129-
step2 = pm.Metropolis([model["b_log__"]])
124+
b = pm.HalfNormal("b")
125+
step1 = pm.NUTS([model.rvs_to_values[a]])
126+
step2 = pm.Metropolis([model.rvs_to_values[b]])
130127

131128
step = pm.CompoundStep([step1, step2])
132129

@@ -138,7 +135,7 @@ def test_explicit_sample():
138135
chain=3,
139136
seed=1,
140137
mp_ctx=ctx,
141-
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
138+
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
142139
step_method_pickled=None,
143140
)
144141
proc.start()
@@ -153,19 +150,16 @@ def test_explicit_sample():
153150
proc.join()
154151

155152

156-
@pytest.mark.xfail(
157-
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
158-
)
159153
def test_iterator():
160154
with pm.Model() as model:
161155
a = pm.Normal("a", shape=1)
162-
pm.HalfNormal("b")
163-
step1 = pm.NUTS([a])
164-
step2 = pm.Metropolis([model["b_log__"]])
156+
b = pm.HalfNormal("b")
157+
step1 = pm.NUTS([model.rvs_to_values[a]])
158+
step2 = pm.Metropolis([model.rvs_to_values[b]])
165159

166160
step = pm.CompoundStep([step1, step2])
167161

168-
start = {"a": np.array([1.0]), "b_log__": np.array(2.0)}
162+
start = {"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}
169163
sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False)
170164
with sampler:
171165
for draw in sampler:

0 commit comments

Comments
 (0)