Skip to content

Commit 3fb5807

Browse files
committed
Deprecate eps argument in math.invlogit
1 parent 4dd0538 commit 3fb5807

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

pymc3/distributions/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class LogOdds(ElemwiseTransform):
170170
name = "logodds"
171171

172172
def backward(self, rv_var, rv_value):
173-
return invlogit(rv_value, 0.0)
173+
return invlogit(rv_value)
174174

175175
def forward(self, rv_var, rv_value):
176176
return logit(rv_value)

pymc3/math.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,15 @@ def logdiffexp_numpy(a, b):
200200
return a + log1mexp_numpy(b - a, negative_input=True)
201201

202202

203-
def invlogit(x, eps=sys.float_info.epsilon):
203+
def invlogit(x, eps=None):
204204
"""The inverse of the logit function, 1 / (1 + exp(-x))."""
205-
return (1.0 - 2.0 * eps) / (1.0 + at.exp(-x)) + eps
205+
if eps is not None:
206+
warnings.warn(
207+
"pymc3.math.invlogit no longer supports the ``eps`` argument and it will be ignored.",
208+
DeprecationWarning,
209+
stacklevel=2,
210+
)
211+
return at.sigmoid(x)
206212

207213

208214
def logbern(log_p):

pymc3/tests/test_math.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
LogDet,
2424
cartesian,
2525
expand_packed_triangular,
26+
invlogit,
2627
invprobit,
2728
kron_dot,
2829
kron_solve_lower,
@@ -250,3 +251,17 @@ def test_expand_packed_triangular():
250251
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
251252
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
252253
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))
254+
255+
256+
def test_invlogit_deprecation_warning():
257+
with pytest.warns(
258+
DeprecationWarning,
259+
match="pymc3.math.invlogit no longer supports the",
260+
):
261+
res = invlogit(np.array(-750.0), 1e-5).eval()
262+
263+
with pytest.warns(None) as record:
264+
res_zero_eps = invlogit(np.array(-750.0)).eval()
265+
assert not record
266+
267+
assert np.isclose(res, res_zero_eps)

0 commit comments

Comments
 (0)