Skip to content

Commit 695c23e

Browse files
committed
Do not change input inplace
1 parent 62c6326 commit 695c23e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pymc3/math.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,12 @@ def log1mexp_numpy(x):
244244
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
245245
"""
246246
x = np.asarray(x)
247-
mask = x < 0.6931471805599453
248-
x[mask] = np.log(-np.expm1(-x[mask]))
247+
out = np.empty_like(x)
248+
mask = x < 0.6931471805599453 # log(2)
249+
out[mask] = np.log(-np.expm1(-x[mask]))
249250
mask = ~mask
250-
x[mask] = np.log1p(-np.exp(-x[mask]))
251-
return x
251+
out[mask] = np.log1p(-np.exp(-x[mask]))
252+
return out
252253

253254

254255
def flatten_list(tensors):

pymc3/tests/test_math.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def test_log1pexp():
133133

134134
def test_log1mexp():
135135
vals = np.array([-1, 0, 1e-20, 1e-4, 10, 100, 1e20])
136+
vals_ = vals.copy()
136137
# import mpmath
137138
# mpmath.mp.dps = 1000
138139
# [float(mpmath.log(1 - mpmath.exp(-x))) for x in vals]
@@ -151,6 +152,8 @@ def test_log1mexp():
151152
npt.assert_allclose(actual, expected)
152153
actual_ = log1mexp_numpy(vals)
153154
npt.assert_allclose(actual_, expected)
155+
# Check that input was not changed in place
156+
npt.assert_allclose(vals, vals_)
154157

155158

156159
def test_log1mexp_numpy_no_warning():

0 commit comments

Comments
 (0)