Skip to content

Commit 6c9e848

Browse files
kyleabeauchamptwiecki
authored andcommitted
Try to fix test_step
1 parent 96b6c6e commit 6c9e848

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

pymc3/tests/test_hmc.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from . import models
55
from pymc3.step_methods.hmc.base_hmc import BaseHMC
66
import pymc3
7+
from pymc3.theanof import floatX
78
from .checks import close_to
9+
from .helpers import select_by_precision
810

911

1012
def test_leapfrog_reversible():
@@ -13,17 +15,16 @@ def test_leapfrog_reversible():
1315
step = BaseHMC(vars=model.vars, model=model)
1416
bij = DictToArrayBijection(step.ordering, start)
1517
q0 = bij.map(start)
16-
p0 = np.ones(n) * .05
17-
18+
p0 = floatX(np.ones(n) * .05)
19+
precision = select_by_precision(float64=1E-8, float32=1E-5)
1820
for epsilon in [.01, .1, 1.2]:
1921
for n_steps in [1, 2, 3, 4, 20]:
2022

2123
q, p = q0, p0
22-
q, p, _ = step.leapfrog(q, p, np.array(epsilon), np.array(n_steps, dtype='int32'))
23-
q, p, _ = step.leapfrog(q, -p, np.array(epsilon), np.array(n_steps, dtype='int32'))
24-
25-
close_to(q, q0, 1e-8, str((n_steps, epsilon)))
26-
close_to(-p, p0, 1e-8, str((n_steps, epsilon)))
24+
q, p, _ = step.leapfrog(q, p, floatX(np.array(epsilon)), np.array(n_steps, dtype='int32'))
25+
q, p, _ = step.leapfrog(q, -p, floatX(np.array(epsilon)), np.array(n_steps, dtype='int32'))
26+
close_to(q, q0, precision, str((n_steps, epsilon)))
27+
close_to(-p, p0, precision, str((n_steps, epsilon)))
2728

2829

2930
def test_leapfrog_reversible_single():
@@ -36,7 +37,8 @@ def test_leapfrog_reversible_single():
3637
for method, step in zip(integrators, steps):
3738
bij = DictToArrayBijection(step.ordering, start)
3839
q0 = bij.map(start)
39-
p0 = np.ones(n) * .05
40+
p0 = floatX(np.ones(n) * .05)
41+
precision = select_by_precision(float64=1E-8, float32=1E-5)
4042
for epsilon in [0.01, 0.1, 1.2]:
4143
for n_steps in [1, 2, 3, 4, 20]:
4244
dlogp0 = step.dlogp(q0)
@@ -46,13 +48,13 @@ def test_leapfrog_reversible_single():
4648

4749
energy = step.compute_energy(q, p)
4850
for _ in range(n_steps):
49-
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
51+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, floatX(np.array(epsilon)))
5052
p = -p
5153
for _ in range(n_steps):
52-
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
54+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, floatX(np.array(epsilon)))
5355

54-
close_to(q, q0, 1e-8, str(('q', method, n_steps, epsilon)))
55-
close_to(-p, p0, 1e-8, str(('p', method, n_steps, epsilon)))
56+
close_to(q, q0, precision, str(('q', method, n_steps, epsilon)))
57+
close_to(-p, p0, precision, str(('p', method, n_steps, epsilon)))
5658

5759

5860
def test_nuts_tuning():

0 commit comments

Comments
 (0)