2
2
3
3
import numpy as np
4
4
5
- from pytensor import config
6
5
from pytensor .compile .ops import ViewOp
7
6
from pytensor .graph .basic import Variable
8
7
from pytensor .link .numba .dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
137
136
138
137
return numba_basic .numba_njit (
139
138
signature ,
140
- fastmath = config .numba__fastmath ,
141
139
# Functions that call a function pointer can't be cached
142
140
cache = False ,
143
141
)(scalar_op_fn )
@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
177
175
signature = create_numba_signature (node , force_scalar = True )
178
176
nary_add_fn = binary_to_nary_func (node .inputs , "add" , "+" )
179
177
180
- return numba_basic .numba_njit (signature , fastmath = config .numba__fastmath )(
181
- nary_add_fn
182
- )
178
+ return numba_basic .numba_njit (signature )(nary_add_fn )
183
179
184
180
185
181
@numba_funcify .register (Mul )
186
182
def numba_funcify_Mul (op , node , ** kwargs ):
187
183
signature = create_numba_signature (node , force_scalar = True )
188
184
nary_add_fn = binary_to_nary_func (node .inputs , "mul" , "*" )
189
185
190
- return numba_basic .numba_njit (signature , fastmath = config .numba__fastmath )(
191
- nary_add_fn
192
- )
186
+ return numba_basic .numba_njit (signature )(nary_add_fn )
193
187
194
188
195
189
@numba_funcify .register (Cast )
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
239
233
240
234
_ = kwargs .pop ("storage_map" , None )
241
235
242
- composite_fn = numba_basic .numba_njit (signature , fastmath = config . numba__fastmath )(
236
+ composite_fn = numba_basic .numba_njit (signature )(
243
237
numba_funcify (op .fgraph , squeeze_output = True , ** kwargs )
244
238
)
245
239
return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
267
261
return numba_basic .global_numba_func (reciprocal )
268
262
269
263
270
- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
264
+ @numba_basic .numba_njit
271
265
def sigmoid (x ):
272
266
return 1 / (1 + np .exp (- x ))
273
267
@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
277
271
return numba_basic .global_numba_func (sigmoid )
278
272
279
273
280
- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
274
+ @numba_basic .numba_njit
281
275
def gammaln (x ):
282
276
return math .lgamma (x )
283
277
@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
287
281
return numba_basic .global_numba_func (gammaln )
288
282
289
283
290
- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
284
+ @numba_basic .numba_njit
291
285
def logp1mexp (x ):
292
286
if x < np .log (0.5 ):
293
287
return np .log1p (- np .exp (x ))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
300
294
return numba_basic .global_numba_func (logp1mexp )
301
295
302
296
303
- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
297
+ @numba_basic .numba_njit
304
298
def erf (x ):
305
299
return math .erf (x )
306
300
@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
310
304
return numba_basic .global_numba_func (erf )
311
305
312
306
313
- @numba_basic .numba_njit ( fastmath = config . numba__fastmath )
307
+ @numba_basic .numba_njit
314
308
def erfc (x ):
315
309
return math .erfc (x )
316
310
0 commit comments