
Commit 92d5450

Implement Polygamma Op
1 parent f6be521 commit 92d5450

7 files changed: +159 additions, -9 deletions

pytensor/gradient.py

Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
     return (
         NullType(
             "This variable is Null because the grad method for "
-            f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
+            f"input {x_pos} ({x}) of the {op} op is undefined. {comment}"
         )
     )()
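
This message tweak is exercised by the new PolyGamma op below, whose gradient with respect to its discrete order input is declared undefined via grad_undefined. A minimal sketch of what a caller would see, mirroring test_grad_n_undefined further down in this commit (it assumes polygamma is re-exported at pytensor.tensor, as the __all__ change below suggests):

    import pytensor.tensor as pt
    from pytensor.gradient import NullTypeGradError, grad

    n = pt.scalar("n", dtype="int64")
    try:
        # grad() hits the NullType variable built by grad_undefined above and
        # raises NullTypeGradError for the discrete order input.
        grad(pt.polygamma(n, 0.5), wrt=n)
    except NullTypeGradError as err:
        print(err)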

pytensor/scalar/math.py

Lines changed: 53 additions & 4 deletions

@@ -13,7 +13,7 @@
 import scipy.stats
 
 from pytensor.configdefaults import config
-from pytensor.gradient import grad_not_implemented
+from pytensor.gradient import grad_not_implemented, grad_undefined
 from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
 from pytensor.scalar.basic import abs as scalar_abs
 from pytensor.scalar.basic import (
@@ -473,8 +473,12 @@ def st_impl(x):
     def impl(self, x):
         return TriGamma.st_impl(x)
 
-    def grad(self, inputs, outputs_gradients):
-        raise NotImplementedError()
+    def L_op(self, inputs, outputs, outputs_gradients):
+        (x,) = inputs
+        (g_out,) = outputs_gradients
+        if x in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        return [g_out * polygamma(2, x)]
 
     def c_support_code(self, **kwargs):
         # The implementation has been copied from
@@ -541,7 +545,52 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floating point is implemented")
 
 
-tri_gamma = TriGamma(upgrade_to_float, name="tri_gamma")
+# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
+tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
+
+
+class PolyGamma(BinaryScalarOp):
+    """Polygamma function of order n evaluated at x.
+
+    It corresponds to the (n+1)th derivative of the log gamma function.
+
+    TODO: Because the first input is discrete and the output is continuous,
+    the default elemwise inplace won't work, as it always tries to store the results in the first input.
+    """
+
+    nfunc_spec = ("scipy.special.polygamma", 2, 1)
+
+    @staticmethod
+    def output_types_preference(n_type, x_type):
+        if n_type not in discrete_types:
+            raise TypeError(
+                f"Polygamma order parameter must be discrete, got {n_type} dtype"
+            )
+        # Scipy doesn't support it
+        return upgrade_to_float_no_complex(x_type)
+
+    @staticmethod
+    def st_impl(n, x):
+        return scipy.special.polygamma(n, x)
+
+    def impl(self, n, x):
+        return PolyGamma.st_impl(n, x)
+
+    def L_op(self, inputs, outputs, output_gradients):
+        (n, x) = inputs
+        (g_out,) = output_gradients
+        if x in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        return [
+            grad_undefined(self, 0, n),
+            g_out * self(n + 1, x),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+polygamma = PolyGamma(name="polygamma")
 
 
 class Chi2SF(BinaryScalarOp):
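
The two L_op methods above use the standard recurrence for polygamma derivatives, d/dx psi^(n)(x) = psi^(n+1)(x); the TriGamma case is just n = 1. This is not part of the commit, but the rule is easy to sanity-check numerically with SciPy alone:

    import numpy as np
    import scipy.special

    n, x, eps = 2, 1.3, 1e-6
    # Central finite difference of polygamma(n, .) at x ...
    finite_diff = (
        scipy.special.polygamma(n, x + eps) - scipy.special.polygamma(n, x - eps)
    ) / (2 * eps)
    # ... agrees with polygamma(n + 1, x), the expression returned by L_op.
    np.testing.assert_allclose(finite_diff, scipy.special.polygamma(n + 1, x), rtol=1e-5)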

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions

@@ -1369,6 +1369,11 @@ def tri_gamma(a):
     """second derivative of the log gamma function"""
 
 
+@scalar_elemwise
+def polygamma(n, x):
+    """Polygamma function of order n evaluated at x"""
+
+
 @scalar_elemwise
 def chi2sf(x, k):
     """chi squared survival function"""
@@ -3008,6 +3013,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     "psi",
     "digamma",
     "tri_gamma",
+    "polygamma",
     "chi2sf",
     "gammainc",
     "gammaincc",

pytensor/tensor/rewriting/math.py

Lines changed: 22 additions & 2 deletions

@@ -52,6 +52,7 @@
 from pytensor.tensor.math import abs as at_abs
 from pytensor.tensor.math import (
     add,
+    digamma,
     dot,
     eq,
     erf,
@@ -68,7 +69,7 @@
     makeKeepDims,
 )
 from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum, mul, neg
+from pytensor.tensor.math import maximum, mul, neg, polygamma
 from pytensor.tensor.math import pow as at_pow
 from pytensor.tensor.math import (
     prod,
@@ -81,7 +82,7 @@
     sub,
 )
 from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import true_div
+from pytensor.tensor.math import tri_gamma, true_div
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
     x = node.inputs[0]
     if x.type.dtype not in complex_dtypes:
         return [x]
+
+
+local_polygamma_to_digamma = PatternNodeRewriter(
+    (polygamma, 0, "x"),
+    (digamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_digamma",
+)
+
+register_specialize(local_polygamma_to_digamma)
+
+local_polygamma_to_tri_gamma = PatternNodeRewriter(
+    (polygamma, 1, "x"),
+    (tri_gamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_tri_gamma",
+)
+
+register_specialize(local_polygamma_to_tri_gamma)
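
These two pattern rewrites fire during the specialize phase, so a polygamma call with a constant order of 0 or 1 compiles down to the existing Psi or TriGamma scalar ops (the test added in tests/tensor/rewriting/test_math.py below checks exactly this). A rough way to watch it happen with the graph rewriter directly; the printed form is approximate:

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.vector("x")
    # After specialization, polygamma(0, x) should be rewritten to digamma(x).
    g = rewrite_graph(pt.polygamma(0, x), include=("canonicalize", "specialize"))
    print(g)  # expected to show a Psi (digamma) node rather than PolyGamma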

tests/link/jax/test_scalar.py

Lines changed: 15 additions & 0 deletions

@@ -20,6 +20,7 @@
     iv,
     log,
     log1mexp,
+    polygamma,
     psi,
     sigmoid,
     softplus,
@@ -178,6 +179,20 @@ def test_tri_gamma():
     compare_jax_and_py(fg, [np.array([3.0, 5.0])])
 
 
+def test_polygamma():
+    n = vector("n", dtype="int32")
+    x = vector("x", dtype="float32")
+    out = polygamma(n, x)
+    fg = FunctionGraph([n, x], [out])
+    compare_jax_and_py(
+        fg,
+        [
+            np.array([0, 1, 2]).astype("int32"),
+            np.array([0.5, 0.9, 2.5]).astype("float32"),
+        ],
+    )
+
+
 def test_log1mexp():
     x = vector("x")
     out = log1mexp(x)
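
compare_jax_and_py compiles the graph with the JAX backend and compares the result against the Python/C evaluation, so this test assumes JAX exposes an elementwise polygamma for the backend to dispatch to (presumably located through the op's nfunc_spec). A standalone check, independent of PyTensor, that such a function exists and agrees with SciPy on the same test values:

    import numpy as np
    import scipy.special
    from jax.scipy.special import polygamma as jax_polygamma

    n = np.array([0, 1, 2], dtype="int32")
    x = np.array([0.5, 0.9, 2.5], dtype="float32")
    np.testing.assert_allclose(
        jax_polygamma(n, x), scipy.special.polygamma(n, x), rtol=1e-5
    )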

tests/tensor/rewriting/test_math.py

Lines changed: 18 additions & 2 deletions

@@ -29,7 +29,7 @@
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import debugprint
-from pytensor.scalar import Pow
+from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
@@ -69,7 +69,7 @@
 from pytensor.tensor.math import max as at_max
 from pytensor.tensor.math import maximum
 from pytensor.tensor.math import min as at_min
-from pytensor.tensor.math import minimum, mul, neg, neq
+from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
 from pytensor.tensor.math import pow as pt_pow
 from pytensor.tensor.math import (
     prod,
@@ -4236,3 +4236,19 @@ def test_logdiffexp():
     np.testing.assert_almost_equal(
         f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
     )
+
+
+def test_polygamma_specialization():
+    x = vector("x")
+
+    y1 = polygamma(0, x)
+    y2 = polygamma(1, x)
+    y3 = polygamma(2, x)
+
+    fn = pytensor.function(
+        [x], [y1, y2, y3], mode=get_default_mode().including("specialize")
+    )
+    fn_outs = fn.maker.fgraph.outputs
+    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
+    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
+    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)

tests/tensor/test_math.py

Lines changed: 44 additions & 0 deletions

@@ -7,6 +7,7 @@
 
 import numpy as np
 import pytest
+import scipy.special
 from numpy.testing import assert_array_equal
 from scipy.special import logsumexp as scipy_logsumexp
 
@@ -64,6 +65,7 @@
     cov,
     deg2rad,
     dense_dot,
+    digamma,
     dot,
     eq,
     exp,
@@ -93,6 +95,7 @@
     neg,
     neq,
     outer,
+    polygamma,
     power,
     ptp,
     rad2deg,
@@ -3470,3 +3473,44 @@ def test_dot22_opt(self):
         fn = function([x, y], x @ y, mode="FAST_RUN")
         [node] = fn.maker.fgraph.apply_nodes
         assert isinstance(node.op, Dot22)
+
+
+class TestPolyGamma:
+    def test_basic(self):
+        n = vector("n", dtype="int64")
+        x = scalar("x")
+
+        np.testing.assert_allclose(
+            polygamma(n, x).eval({n: [0, 1], x: 0.5}),
+            scipy.special.polygamma([0, 1], 0.5),
+        )
+
+    def test_continuous_n_raises(self):
+        n = scalar("n", dtype="float64")
+        with pytest.raises(TypeError, match="must be discrete"):
+            polygamma(n, 0.5)
+
+    def test_complex_x_raises(self):
+        x = scalar(dtype="complex128")
+        with pytest.raises(TypeError, match="complex argument not supported"):
+            polygamma(0, x)
+
+    def test_output_dtype(self):
+        n = scalar("n", dtype="int64")
+        assert polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
+        assert polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
+        assert polygamma(n, scalar("x", dtype="int32")).dtype == "float64"
+
+    def test_grad_x(self):
+        x = scalar("x")
+        op_grad = grad(polygamma(0, x), wrt=x)
+        ref_grad = grad(digamma(x), wrt=x)
+        np.testing.assert_allclose(
+            op_grad.eval({x: 0.9}),
+            ref_grad.eval({x: 0.9}),
+        )
+
+    def test_grad_n_undefined(self):
+        n = scalar(dtype="int64")
+        with pytest.raises(NullTypeGradError):
+            grad(polygamma(n, 0.5), wrt=n)
