Skip to content

Commit e88c2f9

Browse files
authored
Set different almost equal tolerance depending on floatX (#3980)
1 parent d0de763 commit e88c2f9

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

pymc3/tests/test_model.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pymc3.distributions import HalfCauchy, Normal, transforms
2424
from pymc3 import Potential, Deterministic
2525
from pymc3.model import ValueGradFunction
26+
from .helpers import select_by_precision
2627

2728

2829
class NewModel(pm.Model):
@@ -192,17 +193,33 @@ def test_matrix_multiplication():
192193
tune=0,
193194
compute_convergence_checks=False,
194195
progressbar=False)
196+
decimal = select_by_precision(7, 5)
195197
for point in posterior.points():
196-
npt.assert_almost_equal(point['matrix'] @ point['transformed'],
197-
point['rv_rv'])
198-
npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'],
199-
point['np_rv'])
200-
npt.assert_almost_equal(point['matrix'] @ np.ones(2),
201-
point['rv_np'])
202-
npt.assert_almost_equal(point['matrix'] @ point['rv_rv'],
203-
point['rv_det'])
204-
npt.assert_almost_equal(point['rv_rv'] @ point['transformed'],
205-
point['det_rv'])
198+
npt.assert_almost_equal(
199+
point['matrix'] @ point['transformed'],
200+
point['rv_rv'],
201+
decimal=decimal,
202+
)
203+
npt.assert_almost_equal(
204+
np.ones((2, 2)) @ point['transformed'],
205+
point['np_rv'],
206+
decimal=decimal,
207+
)
208+
npt.assert_almost_equal(
209+
point['matrix'] @ np.ones(2),
210+
point['rv_np'],
211+
decimal=decimal,
212+
)
213+
npt.assert_almost_equal(
214+
point['matrix'] @ point['rv_rv'],
215+
point['rv_det'],
216+
decimal=decimal,
217+
)
218+
npt.assert_almost_equal(
219+
point['rv_rv'] @ point['transformed'],
220+
point['det_rv'],
221+
decimal=decimal,
222+
)
206223

207224

208225
def test_duplicate_vars():

0 commit comments

Comments
 (0)