@@ -802,6 +802,10 @@ def test_uniform(self):
802
802
lambda value , lower , upper : sp .uniform .logcdf (value , lower , upper - lower ),
803
803
skip_paramdomain_outside_edge_test = True ,
804
804
)
805
+ # Custom logp / logcdf check for invalid parameters
806
+ invalid_dist = Uniform .dist (lower = 1 , upper = 0 )
807
+ assert invalid_dist .logp (0.5 ).tag .test_value == - np .inf
808
+ assert invalid_dist .logcdf (2 ).tag .test_value == - np .inf
805
809
806
810
def test_triangular (self ):
807
811
self .check_logp (
@@ -817,6 +821,14 @@ def test_triangular(self):
817
821
lambda value , c , lower , upper : sp .triang .logcdf (value , c - lower , lower , upper - lower ),
818
822
skip_paramdomain_outside_edge_test = True ,
819
823
)
824
+ # Custom logp check for invalid value
825
+ valid_dist = Triangular .dist (lower = 0 , upper = 1 , c = 2.0 )
826
+ assert np .all (valid_dist .logp (np .array ([1.9 , 2.0 , 2.1 ])).tag .test_value == - np .inf )
827
+
828
+ # Custom logp / logcdf check for invalid parameters
829
+ invalid_dist = Triangular .dist (lower = 1 , upper = 0 , c = 2.0 )
830
+ assert invalid_dist .logp (0.5 ).tag .test_value == - np .inf
831
+ assert invalid_dist .logcdf (2 ).tag .test_value == - np .inf
820
832
821
833
def test_bound_normal (self ):
822
834
PositiveNormal = Bound (Normal , lower = 0.0 )
@@ -850,6 +862,10 @@ def test_discrete_unif(self):
850
862
Rdunif ,
851
863
{"lower" : - Rplusdunif , "upper" : Rplusdunif },
852
864
)
865
+ # Custom logp / logcdf check for invalid parameters
866
+ invalid_dist = DiscreteUniform .dist (lower = 1 , upper = 0 )
867
+ assert invalid_dist .logp (0.5 ).tag .test_value == - np .inf
868
+ assert invalid_dist .logcdf (2 ).tag .test_value == - np .inf
853
869
854
870
def test_flat (self ):
855
871
self .check_logp (Flat , Runif , {}, lambda value : 0 )
0 commit comments