Skip to content

Commit b14d303

Browse files
committed
Remove custom logaddexp, logsumexp, and log1pexp functions and deprecate log1mexp in favor of Aesara implementations.
Closes #4747
1 parent 0b1ecdb commit b14d303

File tree

6 files changed

+93
-101
lines changed

6 files changed

+93
-101
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
- The length of `dims` in the model is now tracked symbolically through `Model.dim_lengths` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
3131
- We now include `cloudpickle` as a required dependency, and no longer depend on `dill` (see [#4858](https://github.com/pymc-devs/pymc3/pull/4858)).
3232
- The `incomplete_beta` function in `pymc3.distributions.dist_math` was replaced by `aesara.tensor.betainc` (see [4857](https://github.com/pymc-devs/pymc3/pull/4857)).
33+
- `math.log1mexp` and `math.log1mexp_numpy` will expect negative inputs in the future. A `FutureWarning` is now raised unless `negative_input=True` is set (see [#4860](https://github.com/pymc-devs/pymc3/pull/4860)).
3334
- ...
3435

3536
## PyMC3 3.11.2 (14 March 2021)

pymc3/distributions/continuous.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
zvalue,
6868
)
6969
from pymc3.distributions.distribution import Continuous
70-
from pymc3.math import log1mexp, log1pexp, logdiffexp, logit
70+
from pymc3.math import logdiffexp, logit
7171
from pymc3.util import UNSET
7272

7373
__all__ = [
@@ -1095,7 +1095,7 @@ def logcdf(
10951095
return bound(
10961096
at.switch(
10971097
at.lt(value, np.inf),
1098-
a + log1pexp(b - a),
1098+
a + at.log1pexp(b - a),
10991099
0,
11001100
),
11011101
0 < value,
@@ -1370,7 +1370,7 @@ def logcdf(value, a, b):
13701370
-------
13711371
TensorVariable
13721372
"""
1373-
logcdf = log1mexp(-(b * at.log1p(-(value ** a))))
1373+
logcdf = at.log1mexp(b * at.log1p(-(value ** a)))
13741374
return bound(at.switch(value < 1, logcdf, 0), value >= 0, a > 0, b > 0)
13751375

13761376

@@ -1462,7 +1462,7 @@ def logcdf(value, mu):
14621462
"""
14631463
lam = at.inv(mu)
14641464
return bound(
1465-
log1mexp(lam * value),
1465+
at.log1mexp(-lam * value),
14661466
0 <= value,
14671467
0 <= lam,
14681468
)
@@ -2711,7 +2711,7 @@ def logcdf(value, alpha, beta):
27112711
"""
27122712
a = (value / beta) ** alpha
27132713
return bound(
2714-
log1mexp(a),
2714+
at.log1mexp(-a),
27152715
0 <= value,
27162716
0 < alpha,
27172717
0 < beta,
@@ -3662,7 +3662,7 @@ def logcdf(value, mu, s):
36623662
"""
36633663

36643664
return bound(
3665-
-log1pexp(-(value - mu) / s),
3665+
-at.log1pexp(-(value - mu) / s),
36663666
0 < s,
36673667
)
36683668

pymc3/distributions/discrete.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from pymc3.distributions.distribution import Discrete
4444
from pymc3.distributions.logprob import _logcdf, _logp
45-
from pymc3.math import log1mexp, logaddexp, logsumexp, sigmoid
45+
from pymc3.math import sigmoid
4646

4747
__all__ = [
4848
"Binomial",
@@ -279,7 +279,7 @@ def logcdf(value, n, alpha, beta):
279279
return bound(
280280
at.switch(
281281
at.lt(value, n),
282-
logsumexp(
282+
at.logsumexp(
283283
BetaBinomial.logp(at.arange(safe_lower, value + 1), n, alpha, beta),
284284
keepdims=False,
285285
),
@@ -826,7 +826,7 @@ def logcdf(value, p):
826826
"""
827827

828828
return bound(
829-
log1mexp(-at.log1p(-p) * value),
829+
at.log1mexp(at.log1p(-p) * value),
830830
0 <= value,
831831
0 <= p,
832832
p <= 1,
@@ -945,7 +945,7 @@ def logcdf(value, good, bad, n):
945945
return bound(
946946
at.switch(
947947
at.lt(value, n),
948-
logsumexp(
948+
at.logsumexp(
949949
HyperGeometric.logp(at.arange(safe_lower, value + 1), good, bad, n),
950950
keepdims=False,
951951
),
@@ -1300,7 +1300,7 @@ def logp(value, psi, theta):
13001300
logp_val = at.switch(
13011301
at.gt(value, 0),
13021302
at.log(psi) + _logp(poisson, value, {}, theta),
1303-
logaddexp(at.log1p(-psi), at.log(psi) - theta),
1303+
at.logaddexp(at.log1p(-psi), at.log(psi) - theta),
13041304
)
13051305

13061306
return bound(
@@ -1328,7 +1328,7 @@ def logcdf(value, psi, theta):
13281328
"""
13291329

13301330
return bound(
1331-
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(poisson, value, {}, theta)),
1331+
at.logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(poisson, value, {}, theta)),
13321332
0 <= value,
13331333
0 <= psi,
13341334
psi <= 1,
@@ -1430,7 +1430,7 @@ def logp(value, psi, n, p):
14301430
logp_val = at.switch(
14311431
at.gt(value, 0),
14321432
at.log(psi) + _logp(binomial, value, {}, n, p),
1433-
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
1433+
at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
14341434
)
14351435

14361436
return bound(
@@ -1460,7 +1460,7 @@ def logcdf(value, psi, n, p):
14601460
"""
14611461

14621462
return bound(
1463-
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(binomial, value, {}, n, p)),
1463+
at.logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(binomial, value, {}, n, p)),
14641464
0 <= value,
14651465
value <= n,
14661466
0 <= psi,
@@ -1583,7 +1583,7 @@ def logp(value, psi, n, p):
15831583
at.switch(
15841584
at.gt(value, 0),
15851585
at.log(psi) + _logp(nbinom, value, {}, n, p),
1586-
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
1586+
at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
15871587
),
15881588
0 <= value,
15891589
0 <= psi,
@@ -1609,7 +1609,7 @@ def logcdf(value, psi, n, p):
16091609
TensorVariable
16101610
"""
16111611
return bound(
1612-
logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(nbinom, value, {}, n, p)),
1612+
at.logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(nbinom, value, {}, n, p)),
16131613
0 <= value,
16141614
0 <= psi,
16151615
psi <= 1,

pymc3/math.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import sys
16+
import warnings
1617

1718
from functools import partial, reduce
1819

@@ -50,6 +51,9 @@
5051
gt,
5152
le,
5253
log,
54+
log1pexp,
55+
logaddexp,
56+
logsumexp,
5357
lt,
5458
maximum,
5559
minimum,
@@ -186,27 +190,14 @@ def tround(*args, **kwargs):
186190
return at.round(*args, **kwargs)
187191

188192

189-
def logsumexp(x, axis=None, keepdims=True):
190-
# Adapted from https://github.com/Theano/Theano/issues/1563
191-
x_max = at.max(x, axis=axis, keepdims=True)
192-
x_max = at.switch(at.isinf(x_max), 0, x_max)
193-
res = at.log(at.sum(at.exp(x - x_max), axis=axis, keepdims=True)) + x_max
194-
return res if keepdims else res.squeeze()
195-
196-
197-
def logaddexp(a, b):
198-
diff = b - a
199-
return at.switch(diff > 0, b + at.log1p(at.exp(-diff)), a + at.log1p(at.exp(diff)))
200-
201-
202193
def logdiffexp(a, b):
203194
"""log(exp(a) - exp(b))"""
204-
return a + log1mexp(a - b)
195+
return a + at.log1mexp(b - a)
205196

206197

207198
def logdiffexp_numpy(a, b):
208199
"""log(exp(a) - exp(b))"""
209-
return a + log1mexp_numpy(a - b)
200+
return a + log1mexp_numpy(b - a, negative_input=True)
210201

211202

212203
def invlogit(x, eps=sys.float_info.epsilon):
@@ -224,15 +215,7 @@ def logit(p):
224215
return at.log(p / (floatX(1) - p))
225216

226217

227-
def log1pexp(x):
228-
"""Return log(1 + exp(x)), also called softplus.
229-
230-
This function is numerically more stable than the naive approach.
231-
"""
232-
return at.softplus(x)
233-
234-
235-
def log1mexp(x):
218+
def log1mexp(x, *, negative_input=False):
236219
r"""Return log(1 - exp(-x)).
237220
238221
This function is numerically more stable than the naive approach.
@@ -246,21 +229,40 @@ def log1mexp(x):
246229
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
247230
248231
"""
249-
return at.switch(at.lt(x, 0.6931471805599453), at.log(-at.expm1(-x)), at.log1p(-at.exp(-x)))
232+
if not negative_input:
233+
warnings.warn(
234+
"pymc3.math.log1mexp will expect a negative input in a future "
235+
"version of PyMC3.\n To suppress this warning set `negative_input=True`",
236+
FutureWarning,
237+
stacklevel=2,
238+
)
239+
x = -x
250240

241+
return at.log1mexp(x)
251242

252-
def log1mexp_numpy(x):
253-
"""Return log(1 - exp(-x)).
243+
244+
def log1mexp_numpy(x, *, negative_input=False):
245+
"""Return log(1 - exp(x)).
254246
This function is numerically more stable than the naive approach.
255247
For details, see
256248
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
257249
"""
258-
x = np.asarray(x)
250+
x = np.asarray(x, dtype="float")
251+
252+
if not negative_input:
253+
warnings.warn(
254+
"pymc3.math.log1mexp_numpy will expect a negative input in a future "
255+
"version of PyMC3.\n To suppress this warning set `negative_input=True`",
256+
FutureWarning,
257+
stacklevel=2,
258+
)
259+
x = -x
260+
259261
out = np.empty_like(x)
260-
mask = x < 0.6931471805599453 # log(2)
261-
out[mask] = np.log(-np.expm1(-x[mask]))
262+
mask = x < -0.6931471805599453 # log(1/2)
263+
out[mask] = np.log1p(-np.exp(x[mask]))
262264
mask = ~mask
263-
out[mask] = np.log1p(-np.exp(-x[mask]))
265+
out[mask] = np.log(-np.expm1(x[mask]))
264266
return out
265267

266268

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,7 @@ def scipy_log_pdf(value, a, b):
11651165
)
11661166

11671167
def scipy_log_cdf(value, a, b):
1168-
return pm.math.log1mexp_numpy(-(b * np.log1p(-(value ** a))))
1168+
return pm.math.log1mexp_numpy(b * np.log1p(-(value ** a)), negative_input=True)
11691169

11701170
self.check_logp(
11711171
Kumaraswamy,

0 commit comments

Comments
 (0)