@@ -596,7 +596,6 @@ def check_logp(
596
596
n_samples = 100 ,
597
597
extra_args = None ,
598
598
scipy_args = None ,
599
- skip_params_fn = lambda x : False ,
600
599
):
601
600
"""
602
601
Generic test for PyMC3 logp methods
@@ -628,9 +627,6 @@ def check_logp(
628
627
the pymc3 distribution logp is calculated
629
628
scipy_args : Dictionary with extra arguments needed to call scipy logp method
630
629
Usually the same as extra_args
631
- skip_params_fn: Callable
632
- A function that takes a ``dict`` of the test points and returns a
633
- boolean indicating whether or not to perform the test.
634
630
"""
635
631
if decimal is None :
636
632
decimal = select_by_precision (float64 = 6 , float32 = 3 )
@@ -652,8 +648,6 @@ def logp_reference(args):
652
648
domains ["value" ] = domain
653
649
for pt in product (domains , n_samples = n_samples ):
654
650
pt = dict (pt )
655
- if skip_params_fn (pt ):
656
- continue
657
651
pt_d = self ._model_input_dict (model , param_vars , pt )
658
652
pt_logp = Point (pt_d , model = model )
659
653
pt_ref = Point (pt , filter_model_vars = False , model = model )
@@ -698,7 +692,6 @@ def check_logcdf(
698
692
n_samples = 100 ,
699
693
skip_paramdomain_inside_edge_test = False ,
700
694
skip_paramdomain_outside_edge_test = False ,
701
- skip_params_fn = lambda x : False ,
702
695
):
703
696
"""
704
697
Generic test for PyMC3 logcdf methods
@@ -739,9 +732,6 @@ def check_logcdf(
739
732
skip_paramdomain_outside_edge_test : Bool
740
733
Whether to run test 2., which checks that pymc3 distribution logcdf
741
734
returns -inf for invalid parameter values outside the supported domain edge
742
- skip_params_fn: Callable
743
- A function that takes a ``dict`` of the test points and returns a
744
- boolean indicating whether or not to perform the test.
745
735
746
736
Returns
747
737
-------
@@ -761,9 +751,6 @@ def check_logcdf(
761
751
762
752
for pt in product (domains , n_samples = n_samples ):
763
753
params = dict (pt )
764
- if skip_params_fn (params ):
765
- continue
766
-
767
754
scipy_eval = scipy_logcdf (** params )
768
755
769
756
value = params .pop ("value" )
@@ -853,7 +840,6 @@ def check_selfconsistency_discrete_logcdf(
853
840
paramdomains ,
854
841
decimal = None ,
855
842
n_samples = 100 ,
856
- skip_params_fn = lambda x : False ,
857
843
):
858
844
"""
859
845
Check that logcdf of discrete distributions matches sum of logps up to value
@@ -872,8 +858,6 @@ def check_selfconsistency_discrete_logcdf(
872
858
873
859
for pt in product (domains , n_samples = n_samples ):
874
860
params = dict (pt )
875
- if skip_params_fn (params ):
876
- continue
877
861
value = params .pop ("value" )
878
862
values = np .arange (domain .lower , value + 1 )
879
863
@@ -1256,20 +1240,17 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
1256
1240
Nat ,
1257
1241
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1258
1242
modified_scipy_hypergeom_logpmf ,
1259
- skip_params_fn = lambda x : x ["N" ] < x ["n" ] or x ["N" ] < x ["k" ],
1260
1243
)
1261
1244
self .check_logcdf (
1262
1245
HyperGeometric ,
1263
1246
Nat ,
1264
1247
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1265
1248
modified_scipy_hypergeom_logcdf ,
1266
- skip_params_fn = lambda x : x ["N" ] < x ["n" ] or x ["N" ] < x ["k" ],
1267
1249
)
1268
1250
self .check_selfconsistency_discrete_logcdf (
1269
1251
HyperGeometric ,
1270
1252
Nat ,
1271
1253
{"N" : NatSmall , "k" : NatSmall , "n" : NatSmall },
1272
- skip_params_fn = lambda x : x ["N" ] < x ["n" ] or x ["N" ] < x ["k" ],
1273
1254
)
1274
1255
1275
1256
def test_negative_binomial (self ):
0 commit comments