@@ -575,28 +575,6 @@ def check_logcdf(
575
575
err_msg = str (pt ),
576
576
)
577
577
578
- def check_selfconsistency_discrete_logcdf (
579
- self , distribution , domain , paramdomains , decimal = None , n_samples = 100
580
- ):
581
- """
582
- Check that logcdf of discrete distributions matches sum of logps up to value
583
- """
584
- domains = paramdomains .copy ()
585
- domains ["value" ] = domain
586
- if decimal is None :
587
- decimal = select_by_precision (float64 = 6 , float32 = 3 )
588
- for pt in product (domains , n_samples = n_samples ):
589
- params = dict (pt )
590
- value = params .pop ("value" )
591
- values = np .arange (domain .lower , value + 1 )
592
- dist = distribution .dist (** params )
593
- assert_almost_equal (
594
- dist .logcdf (value ).tag .test_value ,
595
- logsumexp (dist .logp (values ), keepdims = False ).tag .test_value ,
596
- decimal = decimal ,
597
- err_msg = str (pt ),
598
- )
599
-
600
578
# Test that values below domain evaluate to -np.inf
601
579
if np .isfinite (domain .lower ):
602
580
below_domain = domain .lower - 1
@@ -607,7 +585,9 @@ def check_selfconsistency_discrete_logcdf(
607
585
)
608
586
609
587
# Test that values above domain evaluate to 0
610
- if np .isfinite (domain .upper ):
588
+ # Natural domains do not have inf as the upper edge, but should also be ignored
589
+ not_nat_domain = domain not in (NatSmall , Nat , NatBig , PosNat )
590
+ if not_nat_domain and np .isfinite (domain .upper ):
611
591
above_domain = domain .upper + 1
612
592
assert_equal (
613
593
dist .logcdf (above_domain ).tag .test_value ,
@@ -624,6 +604,28 @@ def check_selfconsistency_discrete_logcdf(
624
604
):
625
605
raise
626
606
607
+ def check_selfconsistency_discrete_logcdf (
608
+ self , distribution , domain , paramdomains , decimal = None , n_samples = 100
609
+ ):
610
+ """
611
+ Check that logcdf of discrete distributions matches sum of logps up to value
612
+ """
613
+ domains = paramdomains .copy ()
614
+ domains ["value" ] = domain
615
+ if decimal is None :
616
+ decimal = select_by_precision (float64 = 6 , float32 = 3 )
617
+ for pt in product (domains , n_samples = n_samples ):
618
+ params = dict (pt )
619
+ value = params .pop ("value" )
620
+ values = np .arange (domain .lower , value + 1 )
621
+ dist = distribution .dist (** params )
622
+ assert_almost_equal (
623
+ dist .logcdf (value ).tag .test_value ,
624
+ logsumexp (dist .logp (values ), keepdims = False ).tag .test_value ,
625
+ decimal = decimal ,
626
+ err_msg = str (pt ),
627
+ )
628
+
627
629
def check_int_to_1 (self , model , value , domain , paramdomains ):
628
630
pdf = model .fastfn (exp (model .logpt ))
629
631
for pt in product (paramdomains , n_samples = 10 ):
0 commit comments