Skip to content

Commit a1e753f

Browse files
committed
Small fix Triangular logp and logcdf methods
1 parent 41a25d5 commit a1e753f

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

pymc3/distributions/continuous.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pymc3.distributions import transforms
3030
from pymc3.distributions.dist_math import (
3131
SplineWrapper,
32-
alltrue_elemwise,
3332
betaln,
3433
bound,
3534
clipped_beta_rvs,
@@ -3649,18 +3648,14 @@ def logp(self, value):
36493648
c = self.c
36503649
lower = self.lower
36513650
upper = self.upper
3652-
return tt.switch(
3653-
alltrue_elemwise([lower <= value, value < c]),
3654-
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
3651+
return bound(
36553652
tt.switch(
3656-
tt.eq(value, c),
3657-
tt.log(2 / (upper - lower)),
3658-
tt.switch(
3659-
alltrue_elemwise([c < value, value <= upper]),
3660-
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
3661-
np.inf,
3662-
),
3653+
tt.lt(value, c),
3654+
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
3655+
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
36633656
),
3657+
lower <= value,
3658+
value <= upper,
36643659
)
36653660

36663661
def logcdf(self, value):
@@ -3678,17 +3673,24 @@ def logcdf(self, value):
36783673
-------
36793674
TensorVariable
36803675
"""
3681-
l = self.lower
3682-
u = self.upper
36833676
c = self.c
3684-
return tt.switch(
3685-
tt.le(value, l),
3686-
-np.inf,
3677+
lower = self.lower
3678+
upper = self.upper
3679+
return bound(
36873680
tt.switch(
3688-
tt.le(value, c),
3689-
tt.log(((value - l) ** 2) / ((u - l) * (c - l))),
3690-
tt.switch(tt.lt(value, u), tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))), 0),
3681+
tt.le(value, lower),
3682+
-np.inf,
3683+
tt.switch(
3684+
tt.le(value, c),
3685+
tt.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))),
3686+
tt.switch(
3687+
tt.lt(value, upper),
3688+
tt.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))),
3689+
0,
3690+
),
3691+
),
36913692
),
3693+
lower <= upper,
36923694
)
36933695

36943696

0 commit comments

Comments
 (0)