diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index b6128b2e39..7f4193a4d6 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,5 +1,14 @@ # Release Notes +## PyMC3 vNext (on deck) + +### Breaking Changes + +### New Features + +### Maintenance +- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). + ## PyMC3 3.11.0 (21 January 2021) This release breaks some APIs w.r.t. `3.10.0`. It also brings some dreadfully awaited fixes, so be sure to go through the (breaking) changes below. diff --git a/pymc3/math.py b/pymc3/math.py index fc2a55823c..aff54d13b7 100644 --- a/pymc3/math.py +++ b/pymc3/math.py @@ -243,7 +243,13 @@ def log1mexp_numpy(x): For details, see https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ - return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x))) + x = np.asarray(x) + out = np.empty_like(x) + mask = x < 0.6931471805599453 # log(2) + out[mask] = np.log(-np.expm1(-x[mask])) + mask = ~mask + out[mask] = np.log1p(-np.exp(-x[mask])) + return out def flatten_list(tensors): diff --git a/pymc3/tests/test_math.py b/pymc3/tests/test_math.py index 74a782555e..b31319021f 100644 --- a/pymc3/tests/test_math.py +++ b/pymc3/tests/test_math.py @@ -133,6 +133,7 @@ def test_log1pexp(): def test_log1mexp(): vals = np.array([-1, 0, 1e-20, 1e-4, 10, 100, 1e20]) + vals_ = vals.copy() # import mpmath # mpmath.mp.dps = 1000 # [float(mpmath.log(1 - mpmath.exp(-x))) for x in vals] @@ -151,6 +152,15 @@ def test_log1mexp(): npt.assert_allclose(actual, expected) actual_ = log1mexp_numpy(vals) npt.assert_allclose(actual_, expected) + # Check that input was not changed in place + npt.assert_allclose(vals, vals_) + + +def test_log1mexp_numpy_no_warning(): + """Assert RuntimeWarning is not raised for very small numbers""" + with pytest.warns(None) as record: + log1mexp_numpy(1e-25) + assert not record class TestLogDet(SeededTest):