|
27 | 27 | import scipy.stats
|
28 | 28 |
|
29 | 29 |
|
30 |
| -def select_decimal(float64, float32): |
| 30 | +def select_by_precision(float64, float32): |
31 | 31 | """Helper function to choose reasonable decimal cutoffs for different floatX modes."""
|
32 | 32 | decimal = float64 if theano.config.floatX == "float64" else float32
|
33 | 33 | return decimal
|
@@ -314,7 +314,7 @@ def check_logp(self, model, value, domain, paramdomains, logp_reference):
|
314 | 314 | logp = model.fastlogp
|
315 | 315 | for pt in product(domains, n_samples=100):
|
316 | 316 | pt = Point(pt, model=model)
|
317 |
| - assert_almost_equal(logp(pt), logp_reference(pt), decimal=select_decimal(float64=6, float32=2), err_msg=str(pt)) |
| 317 | + assert_almost_equal(logp(pt), logp_reference(pt), decimal=select_by_precision(float64=6, float32=2), err_msg=str(pt)) |
318 | 318 |
|
319 | 319 | def check_int_to_1(self, model, value, domain, paramdomains):
|
320 | 320 | pdf = model.fastfn(exp(model.logpt))
|
@@ -350,7 +350,7 @@ def wrapped_logp(x):
|
350 | 350 | for pt in product(domains, n_samples=100):
|
351 | 351 | pt = Point(pt, model=model)
|
352 | 352 | pt = bij.map(pt)
|
353 |
| - assert_almost_equal(dlogp(pt), ndlogp(pt), decimal=select_decimal(float64=6, float32=4), err_msg=str(pt)) |
| 353 | + assert_almost_equal(dlogp(pt), ndlogp(pt), decimal=select_by_precision(float64=6, float32=4), err_msg=str(pt)) |
354 | 354 |
|
355 | 355 | def checkd(self, distfam, valuedomain, vardomains, checks=None, extra_args={}):
|
356 | 356 | if checks is None:
|
@@ -423,7 +423,7 @@ def test_wald(self, value, mu, lam, phi, alpha, logp):
|
423 | 423 | with Model() as model:
|
424 | 424 | Wald('wald', mu=mu, lam=lam, phi=phi, alpha=alpha, transform=None)
|
425 | 425 | pt = {'wald': value}
|
426 |
| - assert_almost_equal(model.fastlogp(pt), logp, decimal=select_decimal(float64=6, float32=1), err_msg=str(pt)) |
| 426 | + assert_almost_equal(model.fastlogp(pt), logp, decimal=select_by_precision(float64=6, float32=1), err_msg=str(pt)) |
427 | 427 |
|
428 | 428 | def test_beta(self):
|
429 | 429 | self.pymc3_matches_scipy(Beta, Unit, {'alpha': Rplus, 'beta': Rplus},
|
@@ -567,7 +567,7 @@ def test_lkj(self, x, n, p, lp):
|
567 | 567 | LKJCorr('lkj', n=n, p=p, transform=None)
|
568 | 568 |
|
569 | 569 | pt = {'lkj': x}
|
570 |
| - assert_almost_equal(model.fastlogp(pt), lp, decimal=select_decimal(float64=6, float32=4), err_msg=str(pt)) |
| 570 | + assert_almost_equal(model.fastlogp(pt), lp, decimal=select_by_precision(float64=6, float32=4), err_msg=str(pt)) |
571 | 571 |
|
572 | 572 | @parameterized.expand([(2,), (3,)])
|
573 | 573 | def test_dirichlet(self, n):
|
|
0 commit comments