Skip to content

Commit 7411a08

Browse files
committed
More direct access to special functions
1 parent 0b94be0 commit 7411a08

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

pytensor/scalar/math.py

+31-31
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from textwrap import dedent
1010

1111
import numpy as np
12-
import scipy.special
12+
from scipy import special
1313

1414
from pytensor.configdefaults import config
1515
from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -52,7 +52,7 @@ class Erf(UnaryScalarOp):
5252
nfunc_spec = ("scipy.special.erf", 1, 1)
5353

5454
def impl(self, x):
55-
return scipy.special.erf(x)
55+
return special.erf(x)
5656

5757
def L_op(self, inputs, outputs, grads):
5858
(x,) = inputs
@@ -86,7 +86,7 @@ class Erfc(UnaryScalarOp):
8686
nfunc_spec = ("scipy.special.erfc", 1, 1)
8787

8888
def impl(self, x):
89-
return scipy.special.erfc(x)
89+
return special.erfc(x)
9090

9191
def L_op(self, inputs, outputs, grads):
9292
(x,) = inputs
@@ -113,7 +113,7 @@ def c_code(self, node, name, inp, out, sub):
113113
return f"{z} = erfc(({cast}){x});"
114114

115115

116-
# scipy.special.erfc don't support complex. Why?
116+
# special.erfc don't support complex. Why?
117117
erfc = Erfc(upgrade_to_float_no_complex, name="erfc")
118118

119119

@@ -135,7 +135,7 @@ class Erfcx(UnaryScalarOp):
135135
nfunc_spec = ("scipy.special.erfcx", 1, 1)
136136

137137
def impl(self, x):
138-
return scipy.special.erfcx(x)
138+
return special.erfcx(x)
139139

140140
def L_op(self, inputs, outputs, grads):
141141
(x,) = inputs
@@ -191,7 +191,7 @@ class Erfinv(UnaryScalarOp):
191191
nfunc_spec = ("scipy.special.erfinv", 1, 1)
192192

193193
def impl(self, x):
194-
return scipy.special.erfinv(x)
194+
return special.erfinv(x)
195195

196196
def L_op(self, inputs, outputs, grads):
197197
(x,) = inputs
@@ -226,7 +226,7 @@ class Erfcinv(UnaryScalarOp):
226226
nfunc_spec = ("scipy.special.erfcinv", 1, 1)
227227

228228
def impl(self, x):
229-
return scipy.special.erfcinv(x)
229+
return special.erfcinv(x)
230230

231231
def L_op(self, inputs, outputs, grads):
232232
(x,) = inputs
@@ -261,7 +261,7 @@ class Owens_t(BinaryScalarOp):
261261
nfunc_spec = ("scipy.special.owens_t", 2, 1)
262262

263263
def impl(self, h, a):
264-
return scipy.special.owens_t(h, a)
264+
return special.owens_t(h, a)
265265

266266
def grad(self, inputs, grads):
267267
(h, a) = inputs
@@ -286,7 +286,7 @@ class Gamma(UnaryScalarOp):
286286
nfunc_spec = ("scipy.special.gamma", 1, 1)
287287

288288
def impl(self, x):
289-
return scipy.special.gamma(x)
289+
return special.gamma(x)
290290

291291
def L_op(self, inputs, outputs, gout):
292292
(x,) = inputs
@@ -321,7 +321,7 @@ class GammaLn(UnaryScalarOp):
321321
nfunc_spec = ("scipy.special.gammaln", 1, 1)
322322

323323
def impl(self, x):
324-
return scipy.special.gammaln(x)
324+
return special.gammaln(x)
325325

326326
def L_op(self, inputs, outputs, grads):
327327
(x,) = inputs
@@ -361,7 +361,7 @@ class Psi(UnaryScalarOp):
361361
nfunc_spec = ("scipy.special.psi", 1, 1)
362362

363363
def impl(self, x):
364-
return scipy.special.psi(x)
364+
return special.psi(x)
365365

366366
def L_op(self, inputs, outputs, grads):
367367
(x,) = inputs
@@ -448,7 +448,7 @@ class TriGamma(UnaryScalarOp):
448448
"""
449449

450450
def impl(self, x):
451-
return scipy.special.polygamma(1, x)
451+
return special.polygamma(1, x)
452452

453453
def L_op(self, inputs, outputs, outputs_gradients):
454454
(x,) = inputs
@@ -547,7 +547,7 @@ def output_types_preference(n_type, x_type):
547547
return upgrade_to_float_no_complex(x_type)
548548

549549
def impl(self, n, x):
550-
return scipy.special.polygamma(n, x)
550+
return special.polygamma(n, x)
551551

552552
def L_op(self, inputs, outputs, output_gradients):
553553
(n, x) = inputs
@@ -574,7 +574,7 @@ class GammaInc(BinaryScalarOp):
574574
nfunc_spec = ("scipy.special.gammainc", 2, 1)
575575

576576
def impl(self, k, x):
577-
return scipy.special.gammainc(k, x)
577+
return special.gammainc(k, x)
578578

579579
def grad(self, inputs, grads):
580580
(k, x) = inputs
@@ -621,7 +621,7 @@ class GammaIncC(BinaryScalarOp):
621621
nfunc_spec = ("scipy.special.gammaincc", 2, 1)
622622

623623
def impl(self, k, x):
624-
return scipy.special.gammaincc(k, x)
624+
return special.gammaincc(k, x)
625625

626626
def grad(self, inputs, grads):
627627
(k, x) = inputs
@@ -668,7 +668,7 @@ class GammaIncInv(BinaryScalarOp):
668668
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
669669

670670
def impl(self, k, x):
671-
return scipy.special.gammaincinv(k, x)
671+
return special.gammaincinv(k, x)
672672

673673
def grad(self, inputs, grads):
674674
(k, x) = inputs
@@ -693,7 +693,7 @@ class GammaIncCInv(BinaryScalarOp):
693693
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
694694

695695
def impl(self, k, x):
696-
return scipy.special.gammainccinv(k, x)
696+
return special.gammainccinv(k, x)
697697

698698
def grad(self, inputs, grads):
699699
(k, x) = inputs
@@ -928,7 +928,7 @@ class GammaU(BinaryScalarOp):
928928
# Note there is no basic SciPy version so no nfunc_spec.
929929

930930
def impl(self, k, x):
931-
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
931+
return special.gammaincc(k, x) * special.gamma(k)
932932

933933
def c_support_code(self, **kwargs):
934934
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -960,7 +960,7 @@ class GammaL(BinaryScalarOp):
960960
# Note there is no basic SciPy version so no nfunc_spec.
961961

962962
def impl(self, k, x):
963-
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
963+
return special.gammainc(k, x) * special.gamma(k)
964964

965965
def c_support_code(self, **kwargs):
966966
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -992,7 +992,7 @@ class Jv(BinaryScalarOp):
992992
nfunc_spec = ("scipy.special.jv", 2, 1)
993993

994994
def impl(self, v, x):
995-
return scipy.special.jv(v, x)
995+
return special.jv(v, x)
996996

997997
def grad(self, inputs, grads):
998998
v, x = inputs
@@ -1017,7 +1017,7 @@ class J1(UnaryScalarOp):
10171017
nfunc_spec = ("scipy.special.j1", 1, 1)
10181018

10191019
def impl(self, x):
1020-
return scipy.special.j1(x)
1020+
return special.j1(x)
10211021

10221022
def grad(self, inputs, grads):
10231023
(x,) = inputs
@@ -1044,7 +1044,7 @@ class J0(UnaryScalarOp):
10441044
nfunc_spec = ("scipy.special.j0", 1, 1)
10451045

10461046
def impl(self, x):
1047-
return scipy.special.j0(x)
1047+
return special.j0(x)
10481048

10491049
def grad(self, inp, grads):
10501050
(x,) = inp
@@ -1071,7 +1071,7 @@ class Iv(BinaryScalarOp):
10711071
nfunc_spec = ("scipy.special.iv", 2, 1)
10721072

10731073
def impl(self, v, x):
1074-
return scipy.special.iv(v, x)
1074+
return special.iv(v, x)
10751075

10761076
def grad(self, inputs, grads):
10771077
v, x = inputs
@@ -1096,7 +1096,7 @@ class I1(UnaryScalarOp):
10961096
nfunc_spec = ("scipy.special.i1", 1, 1)
10971097

10981098
def impl(self, x):
1099-
return scipy.special.i1(x)
1099+
return special.i1(x)
11001100

11011101
def grad(self, inputs, grads):
11021102
(x,) = inputs
@@ -1118,7 +1118,7 @@ class I0(UnaryScalarOp):
11181118
nfunc_spec = ("scipy.special.i0", 1, 1)
11191119

11201120
def impl(self, x):
1121-
return scipy.special.i0(x)
1121+
return special.i0(x)
11221122

11231123
def grad(self, inp, grads):
11241124
(x,) = inp
@@ -1140,7 +1140,7 @@ class Ive(BinaryScalarOp):
11401140
nfunc_spec = ("scipy.special.ive", 2, 1)
11411141

11421142
def impl(self, v, x):
1143-
return scipy.special.ive(v, x)
1143+
return special.ive(v, x)
11441144

11451145
def grad(self, inputs, grads):
11461146
v, x = inputs
@@ -1165,7 +1165,7 @@ class Kve(BinaryScalarOp):
11651165
nfunc_spec = ("scipy.special.kve", 2, 1)
11661166

11671167
def impl(self, v, x):
1168-
return scipy.special.kve(v, x)
1168+
return special.kve(v, x)
11691169

11701170
def L_op(self, inputs, outputs, output_grads):
11711171
v, x = inputs
@@ -1195,7 +1195,7 @@ class Sigmoid(UnaryScalarOp):
11951195
nfunc_spec = ("scipy.special.expit", 1, 1)
11961196

11971197
def impl(self, x):
1198-
return scipy.special.expit(x)
1198+
return special.expit(x)
11991199

12001200
def grad(self, inp, grads):
12011201
(x,) = inp
@@ -1362,7 +1362,7 @@ class BetaInc(ScalarOp):
13621362
nfunc_spec = ("scipy.special.betainc", 3, 1)
13631363

13641364
def impl(self, a, b, x):
1365-
return scipy.special.betainc(a, b, x)
1365+
return special.betainc(a, b, x)
13661366

13671367
def grad(self, inp, grads):
13681368
a, b, x = inp
@@ -1622,7 +1622,7 @@ class BetaIncInv(ScalarOp):
16221622
nfunc_spec = ("scipy.special.betaincinv", 3, 1)
16231623

16241624
def impl(self, a, b, x):
1625-
return scipy.special.betaincinv(a, b, x)
1625+
return special.betaincinv(a, b, x)
16261626

16271627
def grad(self, inputs, grads):
16281628
(a, b, x) = inputs
@@ -1661,7 +1661,7 @@ class Hyp2F1(ScalarOp):
16611661
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
16621662

16631663
def impl(self, a, b, c, z):
1664-
return scipy.special.hyp2f1(a, b, c, z)
1664+
return special.hyp2f1(a, b, c, z)
16651665

16661666
def grad(self, inputs, grads):
16671667
a, b, c, z = inputs

0 commit comments

Comments
 (0)