Skip to content

Commit 752c184

Browse files
committed
Float32 skipif on Beta and StudentT logcdf tests
1 parent 8eae85f commit 752c184

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

pymc3/tests/test_distributions.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def test_wald_logp_custom_points(self, value, mu, lam, phi, alpha, logp):
11511151
decimals = select_by_precision(float64=6, float32=1)
11521152
assert_almost_equal(model.fastlogp(pt), logp, decimal=decimals, err_msg=str(pt))
11531153

1154-
def test_beta(self):
1154+
def test_beta_logp(self):
11551155
self.check_logp(
11561156
Beta,
11571157
Unit,
@@ -1164,12 +1164,17 @@ def test_beta(self):
11641164
{"mu": Unit, "sigma": Rplus},
11651165
beta_mu_sigma,
11661166
)
1167+
1168+
@pytest.mark.skipif(
1169+
condition=(aesara.config.floatX == "float32"),
1170+
reason="Fails on float32 due to numerical issues",
1171+
)
1172+
def test_beta_logcdf(self):
11671173
self.check_logcdf(
11681174
Beta,
11691175
Unit,
11701176
{"alpha": Rplus, "beta": Rplus},
11711177
lambda value, alpha, beta: sp.beta.logcdf(value, alpha, beta),
1172-
decimal=select_by_precision(float64=5, float32=3),
11731178
)
11741179

11751180
def test_kumaraswamy(self):
@@ -1373,7 +1378,7 @@ def test_lognormal(self):
13731378
lambda value, mu, sigma: sp.lognorm.logcdf(value, sigma, 0, np.exp(mu)),
13741379
)
13751380

1376-
def test_t(self):
1381+
def test_studentt_logp(self):
13771382
self.check_logp(
13781383
StudentT,
13791384
R,
@@ -1386,21 +1391,24 @@ def test_t(self):
13861391
{"nu": Rplus, "mu": R, "sigma": Rplus},
13871392
lambda value, nu, mu, sigma: sp.t.logpdf(value, nu, mu, sigma),
13881393
)
1394+
1395+
@pytest.mark.skipif(
1396+
condition=(aesara.config.floatX == "float32"),
1397+
reason="Fails on float32 due to numerical issues",
1398+
)
1399+
def test_studentt_logcdf(self):
13891400
self.check_logcdf(
13901401
StudentT,
13911402
R,
13921403
{"nu": Rplus, "mu": R, "lam": Rplus},
13931404
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
13941405
)
1395-
# TODO: reenable when PR #4736 is merged
1396-
"""
13971406
self.check_logcdf(
13981407
StudentT,
13991408
R,
14001409
{"nu": Rplus, "mu": R, "sigma": Rplus},
14011410
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
14021411
)
1403-
"""
14041412

14051413
def test_cauchy(self):
14061414
self.check_logp(

0 commit comments

Comments
 (0)