Skip to content

Commit ad55b69

Browse files
Implement tensor.special.logit helper (#645)
1 parent d28a5d0 commit ad55b69

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

pytensor/tensor/special.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.link.c.op import COp
99
from pytensor.tensor.basic import as_tensor_variable
1010
from pytensor.tensor.elemwise import get_normalized_batch_axes
11-
from pytensor.tensor.math import gamma, gammaln, neg, sum
11+
from pytensor.tensor.math import gamma, gammaln, log, neg, sum
1212

1313

1414
class SoftmaxGrad(COp):
@@ -780,6 +780,14 @@ def factorial(n):
780780
return gamma(n + 1)
781781

782782

def logit(x):
    """
    Logit function.

    Computes the log-odds ``log(x / (1 - x))`` elementwise, the inverse of
    the logistic sigmoid. Matches ``scipy.special.logit``.

    Parameters
    ----------
    x : TensorLike
        Values in the interval ``[0, 1]``. ``logit(0)`` is ``-inf`` and
        ``logit(1)`` is ``+inf``; values outside ``[0, 1]`` yield ``nan``.

    Returns
    -------
    TensorVariable
        The elementwise log-odds of ``x``.
    """
    return log(x / (1 - x))
783791
def beta(a, b):
784792
"""
785793
Beta function.
@@ -801,6 +809,7 @@ def betaln(a, b):
801809
"log_softmax",
802810
"poch",
803811
"factorial",
812+
"logit",
804813
"beta",
805814
"betaln",
806815
]

tests/tensor/test_special.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from scipy.special import beta as scipy_beta
44
from scipy.special import factorial as scipy_factorial
55
from scipy.special import log_softmax as scipy_log_softmax
6+
from scipy.special import logit as scipy_logit
67
from scipy.special import poch as scipy_poch
78
from scipy.special import softmax as scipy_softmax
89

@@ -18,6 +19,7 @@
1819
betaln,
1920
factorial,
2021
log_softmax,
22+
logit,
2123
poch,
2224
softmax,
2325
)
@@ -206,6 +208,18 @@ def test_factorial(n):
206208
)
207209

208210

def test_logit():
    # Compile a graph applying logit to a symbolic vector and compare the
    # compiled result against scipy's reference implementation.
    _x = vector("x")
    logit_fn = function([_x], logit(_x), allow_input_downcast=True)

    # Probe the full closed unit interval; the endpoints exercise the
    # -inf/+inf limits, which assert_allclose treats as equal.
    grid = np.linspace(0, 1)
    np.testing.assert_allclose(
        logit_fn(grid),
        scipy_logit(grid),
        rtol=1e-7 if config.floatX == "float64" else 1e-5,
    )
209223
def test_beta():
210224
_a, _b = vectors("a", "b")
211225
actual_fn = function([_a, _b], beta(_a, _b))

0 commit comments

Comments
 (0)