Skip to content

Commit 133abe8

Browse files
committed
Implement Kve Op and Kv helper
1 parent 3523bfa commit 133abe8

File tree

5 files changed

+88
-4
lines changed

5 files changed

+88
-4
lines changed

pytensor/link/jax/dispatch/scalar.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
GammaIncInv,
3232
Iv,
3333
Ive,
34+
Kve,
3435
Log1mexp,
3536
Psi,
3637
TriGamma,
@@ -288,9 +289,12 @@ def iv(v, x):
288289

289290
@jax_funcify.register(Ive)
290291
def jax_funcify_Ive(op, **kwargs):
291-
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
292+
return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
293+
292294

293-
return ive
295+
@jax_funcify.register(Kve)
296+
def jax_funcify_Kve(op, **kwargs):
297+
return try_import_tfp_jax_op(op, jax_op_name="bessel_kve")
294298

295299

296300
@jax_funcify.register(Log1mexp)

pytensor/scalar/math.py

+32
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,38 @@ def c_code(self, *args, **kwargs):
12811281
ive = Ive(upgrade_to_float, name="ive")
12821282

12831283

1284+
class Kve(BinaryScalarOp):
1285+
"""Exponentially scaled modified Bessel function of the second kind of real order v."""
1286+
1287+
nfunc_spec = ("scipy.special.kve", 2, 1)
1288+
1289+
@staticmethod
1290+
def st_impl(v, x):
1291+
return scipy.special.kve(v, x)
1292+
1293+
def impl(self, v, x):
1294+
return self.st_impl(v, x)
1295+
1296+
def L_op(self, inputs, outputs, output_grads):
1297+
v, x = inputs
1298+
[kve_vx] = outputs
1299+
[g_out] = output_grads
1300+
# (1 -v/x) * kve(v, x) - kve(v - 1, x)
1301+
kve_vm1x = self(v - 1, x)
1302+
dx = (1 - v / x) * kve_vx - kve_vm1x
1303+
1304+
return [
1305+
grad_not_implemented(self, 0, v),
1306+
g_out * dx,
1307+
]
1308+
1309+
def c_code(self, *args, **kwargs):
1310+
raise NotImplementedError()
1311+
1312+
1313+
kve = Kve(upgrade_to_float, name="kve")
1314+
1315+
12841316
class Sigmoid(UnaryScalarOp):
12851317
"""
12861318
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit

pytensor/tensor/math.py

+12
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,16 @@ def ive(v, x):
12291229
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
12301230

12311231

1232+
@scalar_elemwise
1233+
def kve(v, x):
1234+
"""Exponentially scaled modified Bessel function of the second kind of real order v."""
1235+
1236+
1237+
def kv(v, x):
1238+
"""Modified Bessel function of the second kind of real order v."""
1239+
return kve(v, x) * exp(-x)
1240+
1241+
12321242
@scalar_elemwise
12331243
def sigmoid(x):
12341244
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@@ -3040,6 +3050,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30403050
"i1",
30413051
"iv",
30423052
"ive",
3053+
"kv",
3054+
"kve",
30433055
"sigmoid",
30443056
"expit",
30453057
"softplus",

tests/link/jax/test_scalar.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
gammainccinv,
2222
gammaincinv,
2323
iv,
24+
kve,
2425
log,
2526
log1mexp,
2627
polygamma,
@@ -157,6 +158,7 @@ def test_erfinv():
157158
(erfcx, (0.7,)),
158159
(erfcinv, (0.7,)),
159160
(iv, (0.3, 0.7)),
161+
(kve, (-2.5, 2.0)),
160162
],
161163
)
162164
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")

tests/tensor/test_math_scipy.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from pytensor.gradient import verify_grad
6+
from pytensor.gradient import NullTypeGradError, verify_grad
77
from pytensor.scalar import ScalarLoop
88
from pytensor.tensor.elemwise import Elemwise
99

@@ -18,7 +18,7 @@
1818
from pytensor import tensor as pt
1919
from pytensor.compile.mode import get_default_mode
2020
from pytensor.configdefaults import config
21-
from pytensor.tensor import gammaincc, inplace, vector
21+
from pytensor.tensor import gammaincc, inplace, kv, kve, vector
2222
from tests import unittest_tools as utt
2323
from tests.tensor.utils import (
2424
_good_broadcast_unary_chi2sf,
@@ -1196,3 +1196,37 @@ def test_unused_grad_loop_opt(self, wrt):
11961196
[dd for i, dd in enumerate(expected_dds) if i in wrt],
11971197
rtol=rtol,
11981198
)
1199+
1200+
1201+
def test_kve():
1202+
rng = np.random.default_rng(3772)
1203+
v = vector("v")
1204+
x = vector("x")
1205+
1206+
out = kve(v[:, None], x[None, :])
1207+
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
1208+
test_x = np.linspace(0, 1005, 10, dtype=x.type.dtype)
1209+
1210+
np.testing.assert_allclose(
1211+
out.eval({v: test_v, x: test_x}),
1212+
scipy.special.kve(test_v[:, None], test_x[None, :]),
1213+
)
1214+
1215+
with pytest.raises(NullTypeGradError):
1216+
grad(out.sum(), v)
1217+
1218+
verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng)
1219+
1220+
1221+
def test_kv():
1222+
v = vector("v")
1223+
x = vector("x")
1224+
1225+
out = kv(v[:, None], x[None, :])
1226+
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
1227+
test_x = np.linspace(0, 512, 10, dtype=x.type.dtype)
1228+
1229+
np.testing.assert_allclose(
1230+
out.eval({v: test_v, x: test_x}),
1231+
scipy.special.kv(test_v[:, None], test_x[None, :]),
1232+
)

0 commit comments

Comments
 (0)