@@ -634,6 +634,7 @@ def check_logp(
         n_samples=100,
         extra_args=None,
         scipy_args=None,
+        skip_paramdomain_outside_edge_test=False,
     ):
         """
         Generic test for PyMC logp methods
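
Note: the new flag defaults to False, so every existing check_logp call picks up the out-of-domain parameter check introduced below. The hunks for test_uniform, test_triangular, test_discrete_unif, and the TruncatedNormal checks later in this diff opt out with skip_paramdomain_outside_edge_test=True, presumably because their lower/upper parameter domains are coupled and no single parameter can be pushed past its edge in isolation.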
@@ -679,14 +680,39 @@ def logp_reference(args):
             args.update(scipy_args)
             return scipy_logp(**args)

+        def _model_input_dict(model, param_vars, pt):
+            """Create a dict with only the necessary, transformed logp inputs."""
+            pt_d = {}
+            for k, v in pt.items():
+                rv_var = model.named_vars.get(k)
+                nv = param_vars.get(k, rv_var)
+                nv = getattr(nv.tag, "value_var", nv)
+
+                transform = getattr(nv.tag, "transform", None)
+                if transform:
+                    # todo: the compiled graph behind this should be cached and
+                    # reused (if it isn't already).
+                    v = transform.forward(rv_var, v).eval()
+
+                if nv.name in param_vars:
+                    # update the shared parameter variables in `param_vars`
+                    param_vars[nv.name].set_value(v)
+                else:
+                    # create an argument entry for the (potentially
+                    # transformed) "value" variable
+                    pt_d[nv.name] = v
+
+            return pt_d
+
         model, param_vars = build_model(pymc_dist, domain, paramdomains, extra_args)
         logp_pymc = model.fastlogp_nojac

+        # Test that the supported value and parameter domains match scipy
         domains = paramdomains.copy()
         domains["value"] = domain
         for pt in product(domains, n_samples=n_samples):
             pt = dict(pt)
-            pt_d = self._model_input_dict(model, param_vars, pt)
+            pt_d = _model_input_dict(model, param_vars, pt)
             pt_logp = Point(pt_d, model=model)
             pt_ref = Point(pt, filter_model_vars=False, model=model)
             assert_almost_equal(
@@ -696,29 +722,58 @@ def logp_reference(args):
                 err_msg=str(pt),
             )

-    def _model_input_dict(self, model, param_vars, pt):
-        """Create a dict with only the necessary, transformed logp inputs."""
-        pt_d = {}
-        for k, v in pt.items():
-            rv_var = model.named_vars.get(k)
-            nv = param_vars.get(k, rv_var)
-            nv = getattr(nv.tag, "value_var", nv)
-
-            transform = getattr(nv.tag, "transform", None)
-            if transform:
-                # todo: the compiled graph behind this should be cached and
-                # reused (if it isn't already).
-                v = transform.forward(rv_var, v).eval()
-
-            if nv.name in param_vars:
-                # update the shared parameter variables in `param_vars`
-                param_vars[nv.name].set_value(v)
-            else:
-                # create an argument entry for the (potentially
-                # transformed) "value" variable
-                pt_d[nv.name] = v
+        valid_value = domain.vals[0]
+        valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
+        valid_dist = pymc_dist.dist(**valid_params, **extra_args)

-        return pt_d
+        # Test pymc distribution raises ParameterValueError for scalar parameters outside
+        # the supported domain edges (excluding edges)
+        if not skip_paramdomain_outside_edge_test:
+            # Step 1: collect potential invalid parameters
+            invalid_params = {param: [None, None] for param in paramdomains}
+            for param, paramdomain in paramdomains.items():
+                if np.ndim(paramdomain.lower) != 0:
+                    continue
+                if np.isfinite(paramdomain.lower):
+                    invalid_params[param][0] = paramdomain.lower - 1
+                if np.isfinite(paramdomain.upper):
+                    invalid_params[param][1] = paramdomain.upper + 1
+
+            # Step 2: test invalid parameters, one at a time
+            for invalid_param, invalid_edges in invalid_params.items():
+                for invalid_edge in invalid_edges:
+                    if invalid_edge is None:
+                        continue
+                    test_params = valid_params.copy()  # Shallow copy should be okay
+                    test_params[invalid_param] = at.as_tensor_variable(invalid_edge)
+                    # We need to remove `Assert`s introduced by checks like
+                    # `assert_negative_support` and disable test values;
+                    # otherwise, we won't be able to create the `RandomVariable`
+                    with aesara.config.change_flags(compute_test_value="off"):
+                        invalid_dist = pymc_dist.dist(**test_params, **extra_args)
+                    with aesara.config.change_flags(mode=Mode("py")):
+                        with pytest.raises(ParameterValueError):
+                            logp(invalid_dist, valid_value).eval()
+                            pytest.fail(f"test_params={test_params}, valid_value={valid_value}")
+
+        # Test that values outside of scalar domain support evaluate to -np.inf
+        if np.ndim(domain.lower) != 0:
+            return
+        invalid_values = [None, None]
+        if np.isfinite(domain.lower):
+            invalid_values[0] = domain.lower - 1
+        if np.isfinite(domain.upper):
+            invalid_values[1] = domain.upper + 1
+
+        for invalid_value in invalid_values:
+            if invalid_value is None:
+                continue
+            with aesara.config.change_flags(mode=Mode("py")):
+                assert_equal(
+                    logp(valid_dist, invalid_value).eval(),
+                    -np.inf,
+                    err_msg=str(invalid_value),
+                )

     def check_logcdf(
         self,
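
Note: a minimal standalone sketch of the "Step 1" edge collection above. FakeDomain is a hypothetical stand-in for the test suite's Domain objects (only its lower/upper attributes are used) and is not part of this diff:

    import numpy as np
    from collections import namedtuple

    # Hypothetical stand-in exposing only the attributes the loop reads
    FakeDomain = namedtuple("FakeDomain", ["lower", "upper"])
    paramdomains = {"mu": FakeDomain(-np.inf, np.inf), "sigma": FakeDomain(0, np.inf)}

    invalid_params = {param: [None, None] for param in paramdomains}
    for param, paramdomain in paramdomains.items():
        if np.ndim(paramdomain.lower) != 0:
            continue  # skip non-scalar edges
        if np.isfinite(paramdomain.lower):
            invalid_params[param][0] = paramdomain.lower - 1
        if np.isfinite(paramdomain.upper):
            invalid_params[param][1] = paramdomain.upper + 1

    print(invalid_params)  # {'mu': [None, None], 'sigma': [-1, None]}

Only finite scalar edges produce candidates; everything else stays None and is skipped in Step 2.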
@@ -826,7 +881,7 @@ def check_logcdf(
             for invalid_edge in invalid_edges:
                 if invalid_edge is not None:
                     test_params = valid_params.copy()  # Shallow copy should be okay
-                    test_params[invalid_param] = invalid_edge
+                    test_params[invalid_param] = at.as_tensor_variable(invalid_edge)
                     # We need to remove `Assert`s introduced by checks like
                     # `assert_negative_support` and disable test values;
                     # otherwise, we won't be able to create the
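
Note: both here and in the new check_logp block, the raw invalid scalar is wrapped in at.as_tensor_variable before the distribution is built, so the out-of-range value enters the graph as a constant tensor. A minimal sketch of what that call returns, assuming only that Aesara is importable:

    import aesara.tensor as at

    edge = at.as_tensor_variable(-1)
    # The Python scalar becomes a constant tensor node in the graph
    print(type(edge).__name__, edge.dtype)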
@@ -934,6 +989,7 @@ def test_uniform(self):
             Runif,
             {"lower": -Rplusunif, "upper": Rplusunif},
             lambda value, lower, upper: sp.uniform.logpdf(value, lower, upper - lower),
+            skip_paramdomain_outside_edge_test=True,
         )
         self.check_logcdf(
             Uniform,
@@ -954,6 +1010,7 @@ def test_triangular(self):
             Runif,
             {"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
             lambda value, c, lower, upper: sp.triang.logpdf(value, c - lower, lower, upper - lower),
+            skip_paramdomain_outside_edge_test=True,
         )
         self.check_logcdf(
             Triangular,
@@ -1007,6 +1064,7 @@ def test_discrete_unif(self):
             Rdunif,
             {"lower": -Rplusdunif, "upper": Rplusdunif},
             lambda value, lower, upper: sp.randint.logpmf(value, lower, upper + 1),
+            skip_paramdomain_outside_edge_test=True,
         )
         self.check_logcdf(
             DiscreteUniform,
@@ -1017,7 +1075,7 @@ def test_discrete_unif(self):
         )
         self.check_selfconsistency_discrete_logcdf(
             DiscreteUniform,
-            Rdunif,
+            Domain([-10, 0, 10], "int64"),
             {"lower": -Rplusdunif, "upper": Rplusdunif},
         )
         # Custom logp / logcdf check for invalid parameters
@@ -1029,7 +1087,7 @@ def test_discrete_unif(self):
                 logcdf(invalid_dist, 2).eval()

     def test_flat(self):
-        self.check_logp(Flat, Runif, {}, lambda value: 0)
+        self.check_logp(Flat, R, {}, lambda value: 0)
         with Model():
             x = Flat("a")
         self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5))
@@ -1074,6 +1132,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
             {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig},
             scipy_logp,
             decimal=select_by_precision(float64=6, float32=1),
+            skip_paramdomain_outside_edge_test=True,
         )

         self.check_logp(
@@ -1082,6 +1141,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
             {"mu": R, "sigma": Rplusbig, "upper": Rplusbig},
             functools.partial(scipy_logp, lower=-np.inf),
             decimal=select_by_precision(float64=6, float32=1),
+            skip_paramdomain_outside_edge_test=True,
         )

         self.check_logp(
@@ -1090,6 +1150,7 @@ def scipy_logp(value, mu, sigma, lower, upper):
             {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig},
             functools.partial(scipy_logp, upper=np.inf),
             decimal=select_by_precision(float64=6, float32=1),
+            skip_paramdomain_outside_edge_test=True,
         )

     def test_half_normal(self):
@@ -1679,13 +1740,13 @@ def test_discrete_weibull(self):
         self.check_logp(
             DiscreteWeibull,
             Nat,
-            {"q": Unit, "beta": Rplusdunif},
+            {"q": Unit, "beta": NatSmall},
             discrete_weibull_logpmf,
         )
         self.check_selfconsistency_discrete_logcdf(
             DiscreteWeibull,
             Nat,
-            {"q": Unit, "beta": Rplusdunif},
+            {"q": Unit, "beta": NatSmall},
         )

     def test_poisson(self):
@@ -2088,7 +2149,7 @@ def test_wishart(self, n):
         self.check_logp(
             Wishart,
             PdMatrix(n),
-            {"nu": Domain([3, 4, 2000]), "V": PdMatrix(n)},
+            {"nu": Domain([0, 3, 4, np.inf], "int64"), "V": PdMatrix(n)},
             lambda value, nu, V: scipy.stats.wishart.logpdf(value, np.int(nu), V),
         )

@@ -2255,7 +2316,7 @@ def test_categorical_valid_p(self, p):
     def test_categorical(self, n):
         self.check_logp(
             Categorical,
-            Domain(range(n), dtype="int64", edges=(None, None)),
+            Domain(range(n), dtype="int64", edges=(0, n)),
             {"p": Simplex(n)},
             lambda value, p: categorical_logpdf(value, p),
         )
@@ -2368,8 +2429,8 @@ def test_ex_gaussian_cdf_outside_edges(self):
     def test_vonmises(self):
         self.check_logp(
             VonMises,
-            R,
-            {"mu": Circ, "kappa": Rplus},
+            Circ,
+            {"mu": R, "kappa": Rplus},
             lambda value, mu, kappa: floatX(sp.vonmises.logpdf(value, kappa, loc=mu)),
         )

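
Note: this swaps which domain plays the value role for VonMises: the tested values now come from the circular domain Circ while mu ranges over all of R, which lines up with scipy.stats.vonmises, whose documented support is [-π, π] with an unrestricted location parameter.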