@@ -594,7 +594,6 @@ def check_logp(
594
594
n_samples = 100 ,
595
595
extra_args = None ,
596
596
scipy_args = None ,
597
- skip_params_fn = lambda x : False ,
598
597
):
599
598
"""
600
599
Generic test for PyMC3 logp methods
@@ -626,9 +625,6 @@ def check_logp(
626
625
the pymc3 distribution logp is calculated
627
626
scipy_args : Dictionary with extra arguments needed to call scipy logp method
628
627
Usually the same as extra_args
629
- skip_params_fn: Callable
630
- A function that takes a ``dict`` of the test points and returns a
631
- boolean indicating whether or not to perform the test.
632
628
"""
633
629
if decimal is None :
634
630
decimal = select_by_precision (float64 = 6 , float32 = 3 )
@@ -650,8 +646,6 @@ def logp_reference(args):
650
646
domains ["value" ] = domain
651
647
for pt in product (domains , n_samples = n_samples ):
652
648
pt = dict (pt )
653
- if skip_params_fn (pt ):
654
- continue
655
649
pt_d = self ._model_input_dict (model , param_vars , pt )
656
650
pt_logp = Point (pt_d , model = model )
657
651
pt_ref = Point (pt , filter_model_vars = False , model = model )
@@ -696,7 +690,7 @@ def check_logcdf(
696
690
n_samples = 100 ,
697
691
skip_paramdomain_inside_edge_test = False ,
698
692
skip_paramdomain_outside_edge_test = False ,
699
- skip_params_fn = lambda x : False ,
693
+ skip_nan = False ,
700
694
):
701
695
"""
702
696
Generic test for PyMC3 logcdf methods
@@ -737,9 +731,8 @@ def check_logcdf(
737
731
skip_paramdomain_outside_edge_test : Bool
738
732
Whether to run test 2., which checks that pymc3 distribution logcdf
739
733
returns -inf for invalid parameter values outside the supported domain edge
740
- skip_params_fn: Callable
741
- A function that takes a ``dict`` of the test points and returns a
742
- boolean indicating whether or not to perform the test.
734
+ skip_nan: Bool
735
+ Whether to skip comparison when pymc3 logcdf method evaluates to nan
743
736
744
737
Returns
745
738
-------
@@ -759,9 +752,6 @@ def check_logcdf(
759
752
760
753
for pt in product (domains , n_samples = n_samples ):
761
754
params = dict (pt )
762
- if skip_params_fn (params ):
763
- continue
764
-
765
755
scipy_eval = scipy_logcdf (** params )
766
756
767
757
value = params .pop ("value" )
@@ -770,6 +760,9 @@ def check_logcdf(
770
760
param_vars [param_name ].set_value (param_value )
771
761
pymc3_eval = pymc3_logcdf ({"value" : value })
772
762
763
+ if skip_nan and np .isnan (pymc3_eval ):
764
+ continue
765
+
773
766
params ["value" ] = value # for displaying in err_msg
774
767
assert_almost_equal (
775
768
pymc3_eval ,
@@ -851,7 +844,7 @@ def check_selfconsistency_discrete_logcdf(
851
844
paramdomains ,
852
845
decimal = None ,
853
846
n_samples = 100 ,
854
- skip_params_fn = lambda x : False ,
847
+ skip_nan = False ,
855
848
):
856
849
"""
857
850
Check that logcdf of discrete distributions matches sum of logps up to value
@@ -870,19 +863,25 @@ def check_selfconsistency_discrete_logcdf(
870
863
871
864
for pt in product (domains , n_samples = n_samples ):
872
865
params = dict (pt )
873
- if skip_params_fn (params ):
874
- continue
875
866
value = params .pop ("value" )
876
867
values = np .arange (domain .lower , value + 1 )
877
868
878
869
# Update shared parameter variables in logp/logcdf function
879
870
for param_name , param_value in params .items ():
880
871
param_vars [param_name ].set_value (param_value )
881
872
873
+ logcdf_eval = dist_logcdf ({"value" : value })
874
+ if skip_nan and np .isnan (logcdf_eval ):
875
+ continue
876
+
877
+ logp_logsumexp_eval = scipy .special .logsumexp (
878
+ [dist_logp ({"value" : value }) for value in values ]
879
+ )
880
+
882
881
with aesara .config .change_flags (mode = Mode ("py" )):
883
882
assert_almost_equal (
884
- dist_logcdf ({ "value" : value }) ,
885
- scipy . special . logsumexp ([ dist_logp ({ "value" : value }) for value in values ]) ,
883
+ logcdf_eval ,
884
+ logp_logsumexp_eval ,
886
885
decimal = decimal ,
887
886
err_msg = str (pt ),
888
887
)
@@ -1233,20 +1232,19 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
1233
1232
Nat ,
1234
1233
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1235
1234
modified_scipy_hypergeom_logpmf ,
1236
- skip_params_fn = lambda x : x ["N" ] < x ["n" ] or x ["N" ] < x ["k" ],
1237
1235
)
1238
1236
self .check_logcdf (
1239
1237
HyperGeometric ,
1240
1238
Nat ,
1241
1239
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1242
1240
modified_scipy_hypergeom_logcdf ,
1243
- skip_params_fn = lambda x : x [ "N" ] < x [ "n" ] or x [ "N" ] < x [ "k" ],
1241
+ skip_nan = True , # TODO: Remove once aesara/issues/461 is solved
1244
1242
)
1245
1243
self .check_selfconsistency_discrete_logcdf (
1246
1244
HyperGeometric ,
1247
1245
Nat ,
1248
1246
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1249
- skip_params_fn = lambda x : x [ "N" ] < x [ "n" ] or x [ "N" ] < x [ "k" ],
1247
+ skip_nan = True , # TODO: Remove once aesara/issues/461 is solved
1250
1248
)
1251
1249
1252
1250
def test_negative_binomial (self ):
0 commit comments