7
7
import os
8
8
9
9
import numpy as np
10
+ import scipy .special
11
+ import scipy .stats
10
12
11
13
from aesara .configdefaults import config
12
14
from aesara .gradient import grad_not_implemented
25
27
)
26
28
27
29
28
- imported_scipy_special = False
29
- try :
30
- import scipy .special
31
- import scipy .stats
32
-
33
- imported_scipy_special = True
34
- # Importing scipy.special may raise ValueError.
35
- # See http://projects.scipy.org/scipy/ticket/1739
36
- except (ImportError , ValueError ):
37
- pass
38
-
39
-
40
30
class Erf (UnaryScalarOp ):
41
31
nfunc_spec = ("scipy.special.erf" , 1 , 1 )
42
32
43
33
def impl (self , x ):
44
- if imported_scipy_special :
45
- return scipy .special .erf (x )
46
- else :
47
- super ().impl (x )
34
+ return scipy .special .erf (x )
48
35
49
36
def L_op (self , inputs , outputs , grads ):
50
37
(x ,) = inputs
@@ -78,10 +65,7 @@ class Erfc(UnaryScalarOp):
78
65
nfunc_spec = ("scipy.special.erfc" , 1 , 1 )
79
66
80
67
def impl (self , x ):
81
- if imported_scipy_special :
82
- return scipy .special .erfc (x )
83
- else :
84
- super ().impl (x )
68
+ return scipy .special .erfc (x )
85
69
86
70
def L_op (self , inputs , outputs , grads ):
87
71
(x ,) = inputs
@@ -130,10 +114,7 @@ class Erfcx(UnaryScalarOp):
130
114
nfunc_spec = ("scipy.special.erfcx" , 1 , 1 )
131
115
132
116
def impl (self , x ):
133
- if imported_scipy_special :
134
- return scipy .special .erfcx (x )
135
- else :
136
- super ().impl (x )
117
+ return scipy .special .erfcx (x )
137
118
138
119
def L_op (self , inputs , outputs , grads ):
139
120
(x ,) = inputs
@@ -195,10 +176,7 @@ class Erfinv(UnaryScalarOp):
195
176
nfunc_spec = ("scipy.special.erfinv" , 1 , 1 )
196
177
197
178
def impl (self , x ):
198
- if imported_scipy_special :
199
- return scipy .special .erfinv (x )
200
- else :
201
- super ().impl (x )
179
+ return scipy .special .erfinv (x )
202
180
203
181
def L_op (self , inputs , outputs , grads ):
204
182
(x ,) = inputs
@@ -232,10 +210,7 @@ class Erfcinv(UnaryScalarOp):
232
210
nfunc_spec = ("scipy.special.erfcinv" , 1 , 1 )
233
211
234
212
def impl (self , x ):
235
- if imported_scipy_special :
236
- return scipy .special .erfcinv (x )
237
- else :
238
- super ().impl (x )
213
+ return scipy .special .erfcinv (x )
239
214
240
215
def L_op (self , inputs , outputs , grads ):
241
216
(x ,) = inputs
@@ -273,10 +248,7 @@ def st_impl(x):
273
248
return scipy .special .gamma (x )
274
249
275
250
def impl (self , x ):
276
- if imported_scipy_special :
277
- return Gamma .st_impl (x )
278
- else :
279
- super ().impl (x )
251
+ return Gamma .st_impl (x )
280
252
281
253
def L_op (self , inputs , outputs , gout ):
282
254
(x ,) = inputs
@@ -315,10 +287,7 @@ def st_impl(x):
315
287
return scipy .special .gammaln (x )
316
288
317
289
def impl (self , x ):
318
- if imported_scipy_special :
319
- return GammaLn .st_impl (x )
320
- else :
321
- super ().impl (x )
290
+ return GammaLn .st_impl (x )
322
291
323
292
def L_op (self , inputs , outputs , grads ):
324
293
(x ,) = inputs
@@ -362,10 +331,7 @@ def st_impl(x):
362
331
return scipy .special .psi (x )
363
332
364
333
def impl (self , x ):
365
- if imported_scipy_special :
366
- return Psi .st_impl (x )
367
- else :
368
- super ().impl (x )
334
+ return Psi .st_impl (x )
369
335
370
336
def L_op (self , inputs , outputs , grads ):
371
337
(x ,) = inputs
@@ -456,10 +422,7 @@ def st_impl(x):
456
422
return scipy .special .polygamma (1 , x )
457
423
458
424
def impl (self , x ):
459
- if imported_scipy_special :
460
- return TriGamma .st_impl (x )
461
- else :
462
- super ().impl (x )
425
+ return TriGamma .st_impl (x )
463
426
464
427
def grad (self , inputs , outputs_gradients ):
465
428
raise NotImplementedError ()
@@ -545,10 +508,7 @@ def st_impl(x, k):
545
508
return scipy .stats .chi2 .sf (x , k )
546
509
547
510
def impl (self , x , k ):
548
- if imported_scipy_special :
549
- return Chi2SF .st_impl (x , k )
550
- else :
551
- super ().impl (x , k )
511
+ return Chi2SF .st_impl (x , k )
552
512
553
513
def c_support_code (self , ** kwargs ):
554
514
with open (os .path .join (os .path .dirname (__file__ ), "c_code" , "gamma.c" )) as f :
@@ -589,10 +549,7 @@ def st_impl(k, x):
589
549
return scipy .special .gammainc (k , x )
590
550
591
551
def impl (self , k , x ):
592
- if imported_scipy_special :
593
- return GammaInc .st_impl (k , x )
594
- else :
595
- super ().impl (k , x )
552
+ return GammaInc .st_impl (k , x )
596
553
597
554
def c_support_code (self , ** kwargs ):
598
555
with open (os .path .join (os .path .dirname (__file__ ), "c_code" , "gamma.c" )) as f :
@@ -633,10 +590,7 @@ def st_impl(k, x):
633
590
return scipy .special .gammaincc (x , k )
634
591
635
592
def impl (self , k , x ):
636
- if imported_scipy_special :
637
- return GammaIncC .st_impl (k , x )
638
- else :
639
- super ().impl (k , x )
593
+ return GammaIncC .st_impl (k , x )
640
594
641
595
def c_support_code (self , ** kwargs ):
642
596
with open (os .path .join (os .path .dirname (__file__ ), "c_code" , "gamma.c" )) as f :
@@ -677,10 +631,7 @@ def st_impl(k, x):
677
631
return scipy .special .gammaincc (k , x ) * scipy .special .gamma (k )
678
632
679
633
def impl (self , k , x ):
680
- if imported_scipy_special :
681
- return GammaU .st_impl (k , x )
682
- else :
683
- super ().impl (k , x )
634
+ return GammaU .st_impl (k , x )
684
635
685
636
def c_support_code (self , ** kwargs ):
686
637
with open (os .path .join (os .path .dirname (__file__ ), "c_code" , "gamma.c" )) as f :
@@ -721,10 +672,7 @@ def st_impl(k, x):
721
672
return scipy .special .gammainc (k , x ) * scipy .special .gamma (k )
722
673
723
674
def impl (self , k , x ):
724
- if imported_scipy_special :
725
- return GammaL .st_impl (k , x )
726
- else :
727
- super ().impl (k , x )
675
+ return GammaL .st_impl (k , x )
728
676
729
677
def c_support_code (self , ** kwargs ):
730
678
with open (os .path .join (os .path .dirname (__file__ ), "c_code" , "gamma.c" )) as f :
@@ -765,10 +713,7 @@ def st_impl(v, x):
765
713
return scipy .special .jv (v , x )
766
714
767
715
def impl (self , v , x ):
768
- if imported_scipy_special :
769
- return self .st_impl (v , x )
770
- else :
771
- super ().impl (v , x )
716
+ return self .st_impl (v , x )
772
717
773
718
def grad (self , inputs , grads ):
774
719
v , x = inputs
@@ -794,10 +739,7 @@ def st_impl(x):
794
739
return scipy .special .j1 (x )
795
740
796
741
def impl (self , x ):
797
- if imported_scipy_special :
798
- return self .st_impl (x )
799
- else :
800
- super ().impl (x )
742
+ return self .st_impl (x )
801
743
802
744
def grad (self , inputs , grads ):
803
745
(x ,) = inputs
@@ -828,10 +770,7 @@ def st_impl(x):
828
770
return scipy .special .j0 (x )
829
771
830
772
def impl (self , x ):
831
- if imported_scipy_special :
832
- return self .st_impl (x )
833
- else :
834
- super ().impl (x )
773
+ return self .st_impl (x )
835
774
836
775
def grad (self , inp , grads ):
837
776
(x ,) = inp
@@ -862,10 +801,7 @@ def st_impl(v, x):
862
801
return scipy .special .iv (v , x )
863
802
864
803
def impl (self , v , x ):
865
- if imported_scipy_special :
866
- return self .st_impl (v , x )
867
- else :
868
- super ().impl (v , x )
804
+ return self .st_impl (v , x )
869
805
870
806
def grad (self , inputs , grads ):
871
807
v , x = inputs
@@ -891,10 +827,7 @@ def st_impl(x):
891
827
return scipy .special .i1 (x )
892
828
893
829
def impl (self , x ):
894
- if imported_scipy_special :
895
- return self .st_impl (x )
896
- else :
897
- super ().impl (x )
830
+ return self .st_impl (x )
898
831
899
832
def grad (self , inputs , grads ):
900
833
(x ,) = inputs
@@ -917,10 +850,7 @@ def st_impl(x):
917
850
return scipy .special .i0 (x )
918
851
919
852
def impl (self , x ):
920
- if imported_scipy_special :
921
- return self .st_impl (x )
922
- else :
923
- super ().impl (x )
853
+ return self .st_impl (x )
924
854
925
855
def grad (self , inp , grads ):
926
856
(x ,) = inp
@@ -939,10 +869,7 @@ class Sigmoid(UnaryScalarOp):
939
869
nfunc_spec = ("scipy.special.expit" , 1 , 1 )
940
870
941
871
def impl (self , x ):
942
- if imported_scipy_special :
943
- return scipy .special .expit (x )
944
- else :
945
- super ().impl (x )
872
+ return scipy .special .expit (x )
946
873
947
874
def grad (self , inp , grads ):
948
875
(x ,) = inp
0 commit comments