Skip to content

Commit 8c51062

Browse files
committed
Ignore finite upper limit in Nat domains.
Move new checks to `check_logcdf`.
1 parent 7c3fb12 commit 8c51062

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

pymc3/tests/test_distributions.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -575,28 +575,6 @@ def check_logcdf(
575575
err_msg=str(pt),
576576
)
577577

578-
def check_selfconsistency_discrete_logcdf(
579-
self, distribution, domain, paramdomains, decimal=None, n_samples=100
580-
):
581-
"""
582-
Check that logcdf of discrete distributions matches sum of logps up to value
583-
"""
584-
domains = paramdomains.copy()
585-
domains["value"] = domain
586-
if decimal is None:
587-
decimal = select_by_precision(float64=6, float32=3)
588-
for pt in product(domains, n_samples=n_samples):
589-
params = dict(pt)
590-
value = params.pop("value")
591-
values = np.arange(domain.lower, value + 1)
592-
dist = distribution.dist(**params)
593-
assert_almost_equal(
594-
dist.logcdf(value).tag.test_value,
595-
logsumexp(dist.logp(values), keepdims=False).tag.test_value,
596-
decimal=decimal,
597-
err_msg=str(pt),
598-
)
599-
600578
# Test that values below domain evaluate to -np.inf
601579
if np.isfinite(domain.lower):
602580
below_domain = domain.lower - 1
@@ -607,7 +585,9 @@ def check_selfconsistency_discrete_logcdf(
607585
)
608586

609587
# Test that values above domain evaluate to 0
610-
if np.isfinite(domain.upper):
588+
# Natural domains do not have inf as the upper edge, but should also be ignored
589+
not_nat_domain = domain not in (NatSmall, Nat, NatBig, PosNat)
590+
if not_nat_domain and np.isfinite(domain.upper):
611591
above_domain = domain.upper + 1
612592
assert_equal(
613593
dist.logcdf(above_domain).tag.test_value,
@@ -624,6 +604,28 @@ def check_selfconsistency_discrete_logcdf(
624604
):
625605
raise
626606

607+
def check_selfconsistency_discrete_logcdf(
608+
self, distribution, domain, paramdomains, decimal=None, n_samples=100
609+
):
610+
"""
611+
Check that logcdf of discrete distributions matches sum of logps up to value
612+
"""
613+
domains = paramdomains.copy()
614+
domains["value"] = domain
615+
if decimal is None:
616+
decimal = select_by_precision(float64=6, float32=3)
617+
for pt in product(domains, n_samples=n_samples):
618+
params = dict(pt)
619+
value = params.pop("value")
620+
values = np.arange(domain.lower, value + 1)
621+
dist = distribution.dist(**params)
622+
assert_almost_equal(
623+
dist.logcdf(value).tag.test_value,
624+
logsumexp(dist.logp(values), keepdims=False).tag.test_value,
625+
decimal=decimal,
626+
err_msg=str(pt),
627+
)
628+
627629
def check_int_to_1(self, model, value, domain, paramdomains):
628630
pdf = model.fastfn(exp(model.logpt))
629631
for pt in product(paramdomains, n_samples=10):

0 commit comments

Comments
 (0)