Skip to content

Commit 5e01fc4

Browse files
committed
Reenable more tests
1 parent 75c2e1e commit 5e01fc4

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,13 @@ jobs:
           # 6th block: These have some XFAILs
           - |
             --ignore=pymc3/tests/test_distributions_timeseries.py
-            --ignore=pymc3/tests/test_missing.py
             --ignore=pymc3/tests/test_mixture.py
             --ignore=pymc3/tests/test_model_graph.py
             --ignore=pymc3/tests/test_modelcontext.py
             --ignore=pymc3/tests/test_parallel_sampling.py
             --ignore=pymc3/tests/test_profile.py
-            --ignore=pymc3/tests/test_random.py
-            --ignore=pymc3/tests/test_shared.py
             --ignore=pymc3/tests/test_smc.py
-            --ignore=pymc3/tests/test_starting.py
             --ignore=pymc3/tests/test_step.py
-            --ignore=pymc3/tests/test_tracetab.py
             --ignore=pymc3/tests/test_tuning.py
             --ignore=pymc3/tests/test_types.py
             --ignore=pymc3/tests/test_variational_inference.py

pymc3/tests/test_missing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ def test_interval_missing_observations():

         assert {"theta1", "theta2"} <= set(prior_trace.keys())

-        trace = sample(chains=1, draws=50, compute_convergence_checks=False)
+        trace = sample(
+            chains=1,
+            draws=50,
+            compute_convergence_checks=False,
+            return_inferencedata=False,
+        )

         assert np.all(0 < trace["theta1_missing"].mean(0))
         assert np.all(0 < trace["theta2_missing"].mean(0))

pymc3/tests/test_starting.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
 # limitations under the License.

 import numpy as np
-
-from pytest import raises
+import pytest

 from pymc3 import (
     Beta,
@@ -47,6 +46,7 @@ def test_accuracy_non_normal():
     close_to(newstart["x"], mu, select_by_precision(float64=1e-5, float32=1e-4))


+@pytest.mark.xfail(reason="find_MAP fails with derivatives")
 def test_find_MAP_discrete():
     tol = 2.0 ** -11
     alpha = 4
@@ -68,12 +68,15 @@ def test_find_MAP_discrete():
     assert map_est2["ss"] == 14


+@pytest.mark.xfail(reason="find_MAP fails with derivatives")
 def test_find_MAP_no_gradient():
     _, model = simple_arbitrary_det()
     with model:
         find_MAP()


+@pytest.mark.skip(reason="test is slow because it's failing")
+@pytest.mark.xfail(reason="find_MAP fails with derivatives")
 def test_find_MAP():
     tol = 2.0 ** -11  # 16 bit machine epsilon, a low bar
     data = np.random.randn(100)
@@ -106,8 +109,8 @@ def test_find_MAP_issue_4488():
     map_estimate = find_MAP()

     assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
-    assert np.isclose(map_estimate["x_missing"], 0.2)
-    np.testing.assert_array_equal(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
+    np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-5, atol=1e-5)
+    np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])


 def test_allinmodel():
@@ -120,11 +123,16 @@ def test_allinmodel():
         x2 = Normal("x2", mu=0, sigma=1)
         y2 = Normal("y2", mu=0, sigma=1)

+    x1 = model1.rvs_to_values[x1]
+    y1 = model1.rvs_to_values[y1]
+    x2 = model2.rvs_to_values[x2]
+    y2 = model2.rvs_to_values[y2]
+
     starting.allinmodel([x1, y1], model1)
     starting.allinmodel([x1], model1)
-    with raises(ValueError, match=r"Some variables not in the model: \['x2', 'y2'\]"):
+    with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2', 'y2'\]"):
         starting.allinmodel([x2, y2], model1)
-    with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
+    with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
         starting.allinmodel([x2, y1], model1)
-    with raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
+    with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
         starting.allinmodel([x2], model1)

0 commit comments

Comments (0)