diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py
index 71ea40de0f..d3e5ac11f7 100644
--- a/pytensor/link/jax/dispatch/scalar.py
+++ b/pytensor/link/jax/dispatch/scalar.py
@@ -31,6 +31,7 @@
     GammaIncInv,
     Iv,
     Ive,
+    Kve,
     Log1mexp,
     Psi,
     TriGamma,
@@ -288,9 +289,12 @@ def iv(v, x):
 
 @jax_funcify.register(Ive)
 def jax_funcify_Ive(op, **kwargs):
-    ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
+    return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
 
-    return ive
+
+@jax_funcify.register(Kve)
+def jax_funcify_Kve(op, **kwargs):
+    return try_import_tfp_jax_op(op, jax_op_name="bessel_kve")
 
 
 @jax_funcify.register(Log1mexp)
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index e3379492fa..a5512c6564 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -1281,6 +1281,38 @@ def c_code(self, *args, **kwargs):
 ive = Ive(upgrade_to_float, name="ive")
 
 
+class Kve(BinaryScalarOp):
+    """Exponentially scaled modified Bessel function of the second kind of real order v."""
+
+    nfunc_spec = ("scipy.special.kve", 2, 1)
+
+    @staticmethod
+    def st_impl(v, x):
+        return scipy.special.kve(v, x)
+
+    def impl(self, v, x):
+        return self.st_impl(v, x)
+
+    def L_op(self, inputs, outputs, output_grads):
+        v, x = inputs
+        [kve_vx] = outputs
+        [g_out] = output_grads
+        # d/dx kve(v, x) = (1 - v/x) * kve(v, x) - kve(v - 1, x)
+        kve_vm1x = self(v - 1, x)
+        dx = (1 - v / x) * kve_vx - kve_vm1x
+
+        return [
+            grad_not_implemented(self, 0, v),
+            g_out * dx,
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+kve = Kve(upgrade_to_float, name="kve")
+
+
 class Sigmoid(UnaryScalarOp):
     """
     Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index d1e4dc6195..8c86a834ea 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1229,6 +1229,16 @@ def ive(v, x):
    """Exponentially scaled modified Bessel function of the first kind of order v (real)."""
 
 
+@scalar_elemwise
+def kve(v, x):
+    """Exponentially scaled modified Bessel function of the second kind of real order v."""
+
+
+def kv(v, x):
+    """Modified Bessel function of the second kind of real order v."""
+    return kve(v, x) * exp(-x)
+
+
 @scalar_elemwise
 def sigmoid(x):
     """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@@ -3040,6 +3050,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     "i1",
     "iv",
     "ive",
+    "kv",
+    "kve",
     "sigmoid",
     "expit",
     "softplus",
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 16df2d1b08..a5321420d8 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -56,6 +56,7 @@
     ge,
     int_div,
     isinf,
+    kve,
     le,
     log,
     log1mexp,
@@ -3494,3 +3495,18 @@ def local_useless_conj(fgraph, node):
 )
 
 register_specialize(local_polygamma_to_tri_gamma)
+
+
+local_log_kv = PatternNodeRewriter(
+    # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
+    # During stabilize -x is converted to -1.0 * x
+    (log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))),
+    (sub, (log, (kve, "v", "x")), "x"),
+    allow_multiple_clients=True,
+    name="local_log_kv",
+    # Start the rewrite from the less likely kve node
+    tracks=[kve],
+    get_nodes=get_clients_at_depth2,
+)
+
+register_stabilize(local_log_kv)
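Aside (not part of the diff): the derivative implemented in Kve.L_op above follows from kve(v, x) = exp(x) * kv(v, x) together with the standard recurrence d/dx kv(v, x) = -kv(v - 1, x) - (v / x) * kv(v, x). A minimal standalone sketch that checks the resulting identity against a central finite difference, using only scipy (values here are illustrative, not taken from the PR's tests):

import numpy as np
from scipy.special import kve

v, x, eps = 4.5, 2.0, 1e-6
# Analytic form used in Kve.L_op: (1 - v/x) * kve(v, x) - kve(v - 1, x)
analytic = (1 - v / x) * kve(v, x) - kve(v - 1, x)
# Central finite difference of kve with respect to x
numeric = (kve(v, x + eps) - kve(v, x - eps)) / (2 * eps)
np.testing.assert_allclose(analytic, numeric, rtol=1e-6)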
diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py
index 0469301791..475062e86c 100644
--- a/tests/link/jax/test_scalar.py
+++ b/tests/link/jax/test_scalar.py
@@ -21,6 +21,7 @@
     gammainccinv,
     gammaincinv,
     iv,
+    kve,
     log,
     log1mexp,
     polygamma,
@@ -157,6 +158,7 @@ def test_erfinv():
         (erfcx, (0.7,)),
         (erfcinv, (0.7,)),
         (iv, (0.3, 0.7)),
+        (kve, (-2.5, 2.0)),
     ],
 )
 @pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 1160562e62..33c61f48bc 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -61,6 +61,7 @@
     ge,
     gt,
     int_div,
+    kv,
     le,
     log,
     log1mexp,
@@ -4578,3 +4579,17 @@ def test_local_batched_matmul_to_core_matmul():
     x_test = rng.normal(size=(5, 3, 2))
     y_test = rng.normal(size=(5, 2, 2))
     np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+
+def test_log_kv_stabilization():
+    x = pt.scalar("x")
+    out = log(kv(4.5, x))
+
+    # Expression would underflow to -inf without rewrite
+    mode = get_default_mode().including("stabilize")
+    # Reference value from mpmath
+    # mpmath.log(mpmath.besselk(4.5, 1000.0))
+    np.testing.assert_allclose(
+        out.eval({x: 1000.0}, mode=mode),
+        -1003.2180912984705,
+    )
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index 6ca9279bca..921aae826b 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from pytensor.gradient import verify_grad
+from pytensor.gradient import NullTypeGradError, verify_grad
 from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import Elemwise
 
@@ -18,7 +18,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.tensor import gammaincc, inplace, vector
+from pytensor.tensor import gammaincc, inplace, kv, kve, vector
 from tests import unittest_tools as utt
 from tests.tensor.utils import (
     _good_broadcast_unary_chi2sf,
@@ -1196,3 +1196,37 @@ def test_unused_grad_loop_opt(self, wrt):
         [dd for i, dd in enumerate(expected_dds) if i in wrt],
         rtol=rtol,
     )
+
+
+def test_kve():
+    rng = np.random.default_rng(3772)
+    v = vector("v")
+    x = vector("x")
+
+    out = kve(v[:, None], x[None, :])
+    test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
+    test_x = np.linspace(0, 1005, 10, dtype=x.type.dtype)
+
+    np.testing.assert_allclose(
+        out.eval({v: test_v, x: test_x}),
+        scipy.special.kve(test_v[:, None], test_x[None, :]),
+    )
+
+    with pytest.raises(NullTypeGradError):
+        grad(out.sum(), v)
+
+    # Check the gradient of kve itself; kv would underflow to 0 at the
+    # large test_x values, making the check vacuous there
+    verify_grad(lambda x: kve(4.5, x), [test_x + 0.5], rng=rng)
+
+
+def test_kv():
+    v = vector("v")
+    x = vector("x")
+
+    out = kv(v[:, None], x[None, :])
+    test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
+    test_x = np.linspace(0, 512, 10, dtype=x.type.dtype)
+
+    np.testing.assert_allclose(
+        out.eval({v: test_v, x: test_x}),
+        scipy.special.kv(test_v[:, None], test_x[None, :]),
+    )
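Aside (not part of the diff): a minimal usage sketch of what the new ops plus the local_log_kv rewrite buy you once this diff is applied. The input value mirrors the one in test_log_kv_stabilization; the default compilation mode includes the stabilize rewrites, so no explicit mode should be needed:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# log(kv(v, x)) is rewritten to log(kve(v, x)) - x during stabilize,
# so it stays finite where kv(v, x) itself underflows to 0.
fn = pytensor.function([x], pt.log(pt.kv(4.5, x)))
print(fn(np.array([1000.0])))  # ~ -1003.218 rather than -inf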