@@ -760,31 +760,40 @@ def logp(
760
760
-------
761
761
TensorVariable
762
762
"""
763
- unbounded_lower = isinstance (lower , TensorConstant ) and np .all (lower .value == - np .inf )
764
- unbounded_upper = isinstance (upper , TensorConstant ) and np .all (upper .value == np .inf )
763
+ is_lower_bounded = not (
764
+ isinstance (lower , TensorConstant ) and np .all (np .isneginf (lower .value ))
765
+ )
766
+ is_upper_bounded = not (isinstance (upper , TensorConstant ) and np .all (np .isinf (upper .value )))
765
767
766
- if not unbounded_lower and not unbounded_upper :
768
+ if is_lower_bounded and is_upper_bounded :
767
769
lcdf_a = normal_lcdf (mu , sigma , lower )
768
770
lcdf_b = normal_lcdf (mu , sigma , upper )
769
771
lsf_a = normal_lccdf (mu , sigma , lower )
770
772
lsf_b = normal_lccdf (mu , sigma , upper )
771
773
norm = at .switch (lower > 0 , logdiffexp (lsf_a , lsf_b ), logdiffexp (lcdf_b , lcdf_a ))
772
- elif not unbounded_lower :
774
+ elif is_lower_bounded :
773
775
norm = normal_lccdf (mu , sigma , lower )
774
- elif not unbounded_upper :
776
+ elif is_upper_bounded :
775
777
norm = normal_lcdf (mu , sigma , upper )
776
778
else :
777
779
norm = 0.0
778
780
779
781
logp = _logprob (normal , (value ,), None , None , None , mu , sigma ) - norm
780
- bounds = []
781
- if not unbounded_lower :
782
- bounds .append (value >= lower )
783
- if not unbounded_upper :
784
- bounds .append (value <= upper )
785
- if not unbounded_lower and not unbounded_upper :
786
- bounds .append (lower <= upper )
787
- return check_parameters (logp , * bounds )
782
+
783
+ if is_lower_bounded :
784
+ logp = at .switch (value < lower , - np .inf , logp )
785
+
786
+ if is_upper_bounded :
787
+ logp = at .switch (value > upper , - np .inf , logp )
788
+
789
+ if is_lower_bounded and is_upper_bounded :
790
+ logp = check_parameters (
791
+ logp ,
792
+ at .le (lower , upper ),
793
+ msg = "lower_bound <= upper_bound" ,
794
+ )
795
+
796
+ return logp
788
797
789
798
790
799
@_default_transform .register (TruncatedNormal )
0 commit comments