Skip to content

Commit ec0cd55

Browse files
kyleabeauchamptwiecki
authored andcommitted
Fix GPU tests in test_mixture and test_models (#2060)
* Switch MAP solver to one that supports float32 * Switch MAP solver to one that supports float32 * Adjust precision in test_model for float32 mode * Fix floatX in test_mixture * Fix floatX in test_quad
1 parent b1ace60 commit ec0cd55

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

pymc3/tests/test_mixture.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .helpers import SeededTest
55
from pymc3 import Dirichlet, Gamma, Metropolis, Mixture, Model, Normal, NormalMixture, Poisson, sample
6+
from pymc3.theanof import floatX
67

78

89
# Generate data
@@ -36,7 +37,7 @@ def setup_class(cls):
3637

3738
def test_mixture_list_of_normals(self):
3839
with Model() as model:
39-
w = Dirichlet('w', np.ones_like(self.norm_w))
40+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
4041
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
4142
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
4243
Mixture('x_obs', w,
@@ -54,7 +55,7 @@ def test_mixture_list_of_normals(self):
5455

5556
def test_normal_mixture(self):
5657
with Model() as model:
57-
w = Dirichlet('w', np.ones_like(self.norm_w))
58+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
5859
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
5960
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
6061
NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
@@ -70,7 +71,7 @@ def test_normal_mixture(self):
7071

7172
def test_poisson_mixture(self):
7273
with Model() as model:
73-
w = Dirichlet('w', np.ones_like(self.pois_w))
74+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
7475
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
7576
Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
7677
step = Metropolis()
@@ -85,7 +86,7 @@ def test_poisson_mixture(self):
8586

8687
def test_mixture_list_of_poissons(self):
8788
with Model() as model:
88-
w = Dirichlet('w', np.ones_like(self.pois_w))
89+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
8990
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
9091
Mixture('x_obs', w,
9192
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],

pymc3/tests/test_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pymc3.distributions import HalfCauchy, Normal
66
from pymc3 import Potential, Deterministic
77
from pymc3.theanof import generator
8+
from .helpers import select_by_precision
89
import pytest
910

1011

@@ -171,7 +172,7 @@ def true_dens():
171172

172173
for i in range(10):
173174
_1, _2, _t = p1(), p2(), next(t)
174-
np.testing.assert_almost_equal(_1, _t)
175+
np.testing.assert_almost_equal(_1, _t, decimal=select_by_precision(float64=7, float32=2)) # Value O(-50,000)
175176
np.testing.assert_almost_equal(_1, _2)
176177
# Done
177178

pymc3/tests/test_quadpotential.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pymc3.step_methods.hmc import quadpotential
77
import pymc3
8+
from pymc3.theanof import floatX
89
import pytest
910

1011

@@ -22,7 +23,7 @@ def test_elemwise_posdef2():
2223

2324
def test_elemwise_velocity():
2425
scaling = np.array([1, 2, 3])
25-
x_ = np.ones_like(scaling)
26+
x_ = floatX(np.ones_like(scaling))
2627
x = tt.vector()
2728
x.tag.test_value = x_
2829
pot = quadpotential.quad_potential(scaling, True, False)
@@ -35,7 +36,7 @@ def test_elemwise_velocity():
3536

3637
def test_elemwise_energy():
3738
scaling = np.array([1, 2, 3])
38-
x_ = np.ones_like(scaling)
39+
x_ = floatX(np.ones_like(scaling))
3940
x = tt.vector()
4041
x.tag.test_value = x_
4142
pot = quadpotential.quad_potential(scaling, True, False)
@@ -50,7 +51,7 @@ def test_equal_diag():
5051
np.random.seed(42)
5152
for _ in range(3):
5253
diag = np.random.rand(5)
53-
x_ = np.random.randn(5)
54+
x_ = floatX(np.random.randn(5))
5455
x = tt.vector()
5556
x.tag.test_value = x_
5657
pots = [
@@ -80,7 +81,7 @@ def test_equal_dense():
8081
cov += 10 * np.eye(5)
8182
inv = np.linalg.inv(cov)
8283
assert np.allclose(inv.dot(cov), np.eye(5))
83-
x_ = np.random.randn(5)
84+
x_ = floatX(np.random.randn(5))
8485
x = tt.vector()
8586
x.tag.test_value = x_
8687
pots = [

0 commit comments

Comments
 (0)