Skip to content

Commit a21fafa

Browse files
authored
Update log1mexp and remove redundant local reimplementations in the library (#4394)
* Fix cutoff value in `log1mexp` and redundant reimplementation in `Exponential.logcdf()` * Get more digits :) * Remove redundant reimplementation in Weibull * Tiny rewrite to use `|` instead of `tt.or_`. Remove TODO comments. * retrigger checks
1 parent e2ce815 commit a21fafa

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

pymc3/distributions/continuous.py

+4-16
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples
4747
from pymc3.distributions.special import log_i0
48-
from pymc3.math import invlogit, logdiffexp, logit
48+
from pymc3.math import invlogit, log1mexp, logdiffexp, logit
4949
from pymc3.theanof import floatX
5050

5151
__all__ = [
@@ -1513,12 +1513,6 @@ def logcdf(self, value):
15131513
Compute the log of cumulative distribution function for the Exponential distribution
15141514
at the specified value.
15151515
1516-
References
1517-
----------
1518-
.. [Machler2012] Martin Mächler (2012).
1519-
"Accurately computing :math:`\log(1-\exp(-\mid a \mid))` Assessed by the Rmpfr
1520-
package"
1521-
15221516
Parameters
15231517
----------
15241518
value: numeric
@@ -1533,9 +1527,9 @@ def logcdf(self, value):
15331527
lam = self.lam
15341528
a = lam * value
15351529
return tt.switch(
1536-
tt.le(value, 0.0),
1530+
tt.le(value, 0.0) | tt.le(lam, 0),
15371531
-np.inf,
1538-
tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))),
1532+
log1mexp(a),
15391533
)
15401534

15411535

@@ -2806,12 +2800,6 @@ def logcdf(self, value):
28062800
Compute the log of the cumulative distribution function for Weibull distribution
28072801
at the specified value.
28082802
2809-
References
2810-
----------
2811-
.. [Machler2012] Martin Mächler (2012).
2812-
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr
2813-
package"
2814-
28152803
Parameters
28162804
----------
28172805
value: numeric
@@ -2828,7 +2816,7 @@ def logcdf(self, value):
28282816
return tt.switch(
28292817
tt.le(value, 0.0),
28302818
-np.inf,
2831-
tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))),
2819+
log1mexp(a),
28322820
)
28332821

28342822

pymc3/math.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,21 @@ def log1pexp(x):
219219

220220

221221
def log1mexp(x):
222-
"""Return log(1 - exp(-x)).
222+
r"""Return log(1 - exp(-x)).
223223
224224
This function is numerically more stable than the naive approach.
225225
226226
For details, see
227227
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
228+
229+
References
230+
----------
231+
.. [Machler2012] Martin Mächler (2012).
232+
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr
233+
package"
234+
228235
"""
229-
return tt.switch(tt.lt(x, 0.683), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))
236+
return tt.switch(tt.lt(x, 0.6931471805599453), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))
230237

231238

232239
def log1mexp_numpy(x):
@@ -235,7 +242,7 @@ def log1mexp_numpy(x):
235242
For details, see
236243
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
237244
"""
238-
return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))
245+
return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))
239246

240247

241248
def flatten_list(tensors):

0 commit comments

Comments
 (0)