
Commit f951743

Implement betaincinv and gammainc[c]inv functions (#502)
1 parent: e969403

File tree: 8 files changed (+296, -1 lines)

pytensor/link/jax/dispatch/scalar.py (18 additions, 0 deletions)

@@ -21,11 +21,14 @@
     Sub,
 )
 from pytensor.scalar.math import (
+    BetaIncInv,
     Erf,
     Erfc,
     Erfcinv,
     Erfcx,
     Erfinv,
+    GammaIncCInv,
+    GammaIncInv,
     Iv,
     Ive,
     Log1mexp,
@@ -226,6 +229,20 @@ def second(x, y):
     return second


+@jax_funcify.register(GammaIncInv)
+def jax_funcify_GammaIncInv(op, **kwargs):
+    gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv")
+
+    return gammaincinv
+
+
+@jax_funcify.register(GammaIncCInv)
+def jax_funcify_GammaIncCInv(op, **kwargs):
+    gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv")
+
+    return gammainccinv
+
+
 @jax_funcify.register(Erf)
 def jax_funcify_Erf(op, node, **kwargs):
     def erf(x):
@@ -250,6 +267,7 @@ def erfinv(x):
     return erfinv


+@jax_funcify.register(BetaIncInv)
 @jax_funcify.register(Erfcx)
 @jax_funcify.register(Erfcinv)
 def jax_funcify_from_tfp(op, **kwargs):
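These dispatches reuse the existing try_import_tfp_jax_op helper, so the JAX backend lowers GammaIncInv and GammaIncCInv to TensorFlow Probability's igammainv and igammacinv, and BetaIncInv goes through the same generic TFP path as Erfcx and Erfcinv. A minimal usage sketch (assuming tensorflow-probability is installed alongside jax, and that gammaincinv is re-exported through pytensor.tensor as usual):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    k = pt.vector("k")
    x = pt.vector("x")
    # mode="JAX" compiles through the JAX linker, which uses the
    # tfp-backed lowering registered above.
    f = pytensor.function([k, x], pt.gammaincinv(k, x), mode="JAX")
    print(f(np.array([5.5, 7.0]), np.array([0.25, 0.7])))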

pytensor/scalar/math.py (95 additions, 0 deletions)

@@ -733,6 +733,64 @@ def __hash__(self):
 gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")


+class GammaIncInv(BinaryScalarOp):
+    """
+    Inverse to the regularized lower incomplete gamma function.
+    """
+
+    nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
+
+    @staticmethod
+    def st_impl(k, x):
+        return scipy.special.gammaincinv(k, x)
+
+    def impl(self, k, x):
+        return GammaIncInv.st_impl(k, x)
+
+    def grad(self, inputs, grads):
+        (k, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, k),
+            gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv")
+
+
+class GammaIncCInv(BinaryScalarOp):
+    """
+    Inverse to the regularized upper incomplete gamma function.
+    """
+
+    nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
+
+    @staticmethod
+    def st_impl(k, x):
+        return scipy.special.gammainccinv(k, x)
+
+    def impl(self, k, x):
+        return GammaIncCInv.st_impl(k, x)
+
+    def grad(self, inputs, grads):
+        (k, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, k),
+            gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+gammainccinv = GammaIncCInv(upgrade_to_float, name="gammainccinv")
+
+
 def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
     init = [as_scalar(x) if x is not None else None for x in init]
     constant = [as_scalar(x) for x in constant]
@@ -1648,6 +1706,43 @@ def inner_loop(
     return grad


+class BetaIncInv(ScalarOp):
+    """
+    Inverse of the regularized incomplete beta function.
+    """
+
+    nfunc_spec = ("scipy.special.betaincinv", 3, 1)
+
+    def impl(self, a, b, x):
+        return scipy.special.betaincinv(a, b, x)
+
+    def grad(self, inputs, grads):
+        (a, b, x) = inputs
+        (gz,) = grads
+        return [
+            grad_not_implemented(self, 0, a),
+            grad_not_implemented(self, 0, b),
+            gz
+            * exp(betaln(a, b))
+            * ((1 - betaincinv(a, b, x)) ** (1 - b))
+            * (betaincinv(a, b, x) ** (1 - a)),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv")
+
+
+def betaln(a, b):
+    """
+    Beta function from gamma function.
+    """
+
+    return gammaln(a) + gammaln(b) - gammaln(a + b)
+
+
 class Hyp2F1(ScalarOp):
     """
     Gaussian hypergeometric function ``2F1(a, b; c; z)``.
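The x-gradients in these grad methods follow from the inverse function theorem applied to the regularized incomplete functions. A sketch of the derivation, using standard formulas: with $y = \mathrm{gammaincinv}(k, x)$ defined by $P(k, y) = x$,

$$\frac{\partial P}{\partial y} = \frac{y^{k-1} e^{-y}}{\Gamma(k)} \quad\Longrightarrow\quad \frac{\partial y}{\partial x} = \left(\frac{\partial P}{\partial y}\right)^{-1} = \Gamma(k)\, e^{y}\, y^{1-k},$$

which is exactly the exp(gammaincinv(k, x)) * gamma(k) * gammaincinv(k, x) ** (1 - k) expression above. For the upper function $Q(k, y) = 1 - P(k, y) = x$ the derivative picks up a minus sign, giving the -exp(...) term in GammaIncCInv. Likewise, with $y = \mathrm{betaincinv}(a, b, x)$ defined by $I_y(a, b) = x$,

$$\frac{\partial I_y}{\partial y} = \frac{y^{a-1} (1-y)^{b-1}}{B(a, b)} \quad\Longrightarrow\quad \frac{\partial y}{\partial x} = B(a, b)\,(1-y)^{1-b}\, y^{1-a} = e^{\mathrm{betaln}(a, b)}\,(1-y)^{1-b}\, y^{1-a}.$$

Gradients with respect to the shape parameters (k, a, b) are declared grad_not_implemented.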

pytensor/tensor/inplace.py (15 additions, 0 deletions)

@@ -283,6 +283,16 @@ def gammal_inplace(k, x):
     """lower incomplete gamma function"""


+@scalar_elemwise
+def gammaincinv_inplace(k, x):
+    """Inverse to the regularized lower incomplete gamma function"""
+
+
+@scalar_elemwise
+def gammainccinv_inplace(k, x):
+    """Inverse of the regularized upper incomplete gamma function"""
+
+
 @scalar_elemwise
 def j0_inplace(x):
     """Bessel function of the first kind of order 0."""
@@ -338,6 +348,11 @@ def betainc_inplace(a, b, x):
     """Regularized incomplete beta function"""


+@scalar_elemwise
+def betaincinv_inplace(a, b, x):
+    """Inverse of the regularized incomplete beta function"""
+
+
 @scalar_elemwise
 def second_inplace(a):
     """Fill `a` with `b`"""

pytensor/tensor/math.py (18 additions, 0 deletions)

@@ -1385,6 +1385,16 @@ def gammal(k, x):
     """Lower incomplete gamma function."""


+@scalar_elemwise
+def gammaincinv(k, x):
+    """Inverse to the regularized lower incomplete gamma function"""
+
+
+@scalar_elemwise
+def gammainccinv(k, x):
+    """Inverse of the regularized upper incomplete gamma function"""
+
+
 @scalar_elemwise
 def hyp2f1(a, b, c, z):
     """Gaussian hypergeometric function."""
@@ -1451,6 +1461,11 @@ def betainc(a, b, x):
     """Regularized incomplete beta function"""


+@scalar_elemwise
+def betaincinv(a, b, x):
+    """Inverse of the regularized incomplete beta function"""
+
+
 @scalar_elemwise
 def real(z):
     """Return real component of complex-valued tensor `z`."""
@@ -3044,6 +3059,8 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
     "gammaincc",
     "gammau",
     "gammal",
+    "gammaincinv",
+    "gammainccinv",
     "j0",
     "j1",
     "jv",
@@ -3057,6 +3074,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
     "log1pexp",
     "log1mexp",
     "betainc",
+    "betaincinv",
     "real",
     "imag",
     "angle",

pytensor/tensor/special.py (19 additions, 1 deletion)

@@ -6,7 +6,7 @@
 from pytensor.graph.basic import Apply
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.math import gamma, neg, sum
+from pytensor.tensor.math import gamma, gammaln, neg, sum


 class SoftmaxGrad(COp):
@@ -752,9 +752,27 @@ def factorial(n):
     return gamma(n + 1)


+def beta(a, b):
+    """
+    Beta function.
+
+    """
+    return (gamma(a) * gamma(b)) / gamma(a + b)
+
+
+def betaln(a, b):
+    """
+    Log beta function.
+
+    """
+    return gammaln(a) + gammaln(b) - gammaln(a + b)
+
+
 __all__ = [
     "softmax",
     "log_softmax",
     "poch",
     "factorial",
+    "beta",
+    "betaln",
 ]
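beta and betaln are plain graph-level helpers built from gamma and gammaln, so they should agree with scipy up to floating-point error. A quick check sketch:

    import numpy as np
    import scipy.special
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.special import beta, betaln

    a = pt.dscalar("a")
    b = pt.dscalar("b")
    f = pytensor.function([a, b], [beta(a, b), betaln(a, b)])

    out_beta, out_betaln = f(3.0, 4.5)
    np.testing.assert_allclose(out_beta, scipy.special.beta(3.0, 4.5))
    np.testing.assert_allclose(out_betaln, scipy.special.betaln(3.0, 4.5))

Note that the gamma-ratio form of beta overflows for even moderately large arguments, which is why the log-space betaln is exported alongside it.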

tests/link/jax/test_scalar.py (35 additions, 0 deletions)

@@ -11,12 +11,15 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import (
+    betaincinv,
     cosh,
     erf,
     erfc,
     erfcinv,
     erfcx,
     erfinv,
+    gammainccinv,
+    gammaincinv,
     iv,
     log,
     log1mexp,
@@ -165,6 +168,38 @@ def test_tfp_ops(op, test_values):
     compare_jax_and_py(fg, test_values)


+def test_betaincinv():
+    a = vector("a", dtype="float64")
+    b = vector("b", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = betaincinv(a, b, x)
+    fg = FunctionGraph([a, b, x], [out])
+    compare_jax_and_py(
+        fg,
+        [
+            np.array([5.5, 7.0]),
+            np.array([5.5, 7.0]),
+            np.array([0.25, 0.7]),
+        ],
+    )
+
+
+def test_gammaincinv():
+    k = vector("k", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = gammaincinv(k, x)
+    fg = FunctionGraph([k, x], [out])
+    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
+
+
+def test_gammainccinv():
+    k = vector("k", dtype="float64")
+    x = vector("x", dtype="float64")
+    out = gammainccinv(k, x)
+    fg = FunctionGraph([k, x], [out])
+    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
+
+
 def test_psi():
     x = scalar("x")
     out = psi(x)
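compare_jax_and_py compiles the graph with the JAX backend and checks the result against the Python backend, so these three tests presumably only pass in an environment where the TFP-backed lowerings can be imported, that is, with tensorflow-probability installed alongside jax.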
