
Commit 1fc678c

Use more specific Numba fastmath flags everywhere
1 parent ab3704b

8 files changed, +63 -42 lines


doc/extending/creating_a_numba_jax_op.rst (+4, -4)

@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumsum(x)

     else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:

@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:

pytensor/link/numba/dispatch/basic.py (+16, -3)

@@ -49,10 +49,23 @@ def global_numba_func(func):
     return func


-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
     kwargs.setdefault("cache", config.numba__cache)
     kwargs.setdefault("no_cpython_wrapper", True)
     kwargs.setdefault("no_cfunc_wrapper", True)
+    if fastmath is None:
+        if config.numba__fastmath:
+            # Opinionated default on fastmath flags
+            # https://llvm.org/docs/LangRef.html#fast-math-flags
+            fastmath = {
+                "arcp",  # Allow Reciprocal
+                "contract",  # Allow floating-point contraction
+                "afn",  # Approximate functions
+                "reassoc",
+                "nsz",  # no-signed zeros
+            }
+        else:
+            fastmath = False

     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba

@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs):
         )

     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], **kwargs)(args[0])
+        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])

-    return numba.njit(*args, **kwargs)
+    return numba.njit(*args, fastmath=fastmath, **kwargs)


 def numba_vectorize(*args, **kwargs):
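
Note: Numba accepts either a boolean or a set of LLVM fast-math flag names for `fastmath`, so the default above is forwarded to `numba.njit` unchanged. A minimal standalone sketch (plain Numba, no PyTensor) suggesting why this subset is safer than `fastmath=True`: `nnan` and `ninf` are deliberately left out, so NaN checks are not compiled away.

    import numba
    import numpy as np

    # Same subset as the new numba_njit default; see
    # https://llvm.org/docs/LangRef.html#fast-math-flags
    FLAGS = {"arcp", "contract", "afn", "reassoc", "nsz"}

    @numba.njit(fastmath=FLAGS)
    def any_nan(x):
        # "nnan"/"ninf" are not enabled, so the NaN check survives optimization
        return np.isnan(x).any()

    print(any_nan(np.array([0.0, np.nan])))  # expected: True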

pytensor/link/numba/dispatch/blockwise.py (-1)

@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         core_op,
         node=core_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

pytensor/link/numba/dispatch/elemwise.py (+5, -14)

@@ -6,7 +6,6 @@
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (

@@ -281,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)

@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )

@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):

     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)

@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:

@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum

@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
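
Because `numba_njit` only fills in `fastmath` when it is left as `None`, call sites like the ones above can simply drop the keyword, while a caller that needs different behavior can presumably still pass it explicitly. A rough sketch under that assumption (the `total` helper is made up for illustration):

    import numpy as np
    from pytensor.link.numba.dispatch import basic as numba_basic

    def total(x):
        return x.sum()

    # Picks up the opinionated flag set (or False) based on config.numba__fastmath
    default_fn = numba_basic.numba_njit(boundscheck=False)(total)

    # An explicit value bypasses the default entirely
    strict_fn = numba_basic.numba_njit(boundscheck=False, fastmath=False)(total)

    print(default_fn(np.arange(4.0)), strict_fn(np.arange(4.0)))  # 6.0 6.0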

pytensor/link/numba/dispatch/extra_ops.py (+4, -5)

@@ -4,7 +4,6 @@
 import numba
 import numpy as np

-from pytensor import config
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify

@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumsum(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:

@@ -74,13 +73,13 @@ def cumop(x):
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:

pytensor/link/numba/dispatch/scalar.py (+8, -14)

@@ -2,7 +2,6 @@

 import numpy as np

-from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic

@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

     return numba_basic.numba_njit(
         signature,
-        fastmath=config.numba__fastmath,
         # Functions that call a function pointer can't be cached
         cache=False,
     )(scalar_op_fn)

@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Cast)

@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+    composite_fn = numba_basic.numba_njit(signature)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn

@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
     return numba_basic.global_numba_func(reciprocal)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))

@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
     return numba_basic.global_numba_func(sigmoid)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def gammaln(x):
     return math.lgamma(x)

@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
     return numba_basic.global_numba_func(gammaln)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def logp1mexp(x):
     if x < np.log(0.5):
         return np.log1p(-np.exp(x))

@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
     return numba_basic.global_numba_func(logp1mexp)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erf(x):
     return math.erf(x)

@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
     return numba_basic.global_numba_func(erf)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erfc(x):
     return math.erfc(x)
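
The helpers above (`sigmoid`, `gammaln`, `erf`, `erfc`, ...) now use the bare `@numba_basic.numba_njit` form. That works because `numba_njit` checks whether its first positional argument is callable and jits it directly, applying the same fastmath default. A small sketch of that usage (the `double` function is hypothetical):

    import numpy as np
    from pytensor.link.numba.dispatch import basic as numba_basic

    @numba_basic.numba_njit  # no parentheses: the decorated function is the first argument
    def double(x):
        return 2 * x

    print(double(np.float64(3.0)))  # 6.0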

tests/link/numba/test_basic.py (+7, -1)

@@ -838,7 +838,13 @@ def test_config_options_fastmath():
     pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
     print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
     numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
-    assert numba_mul_fn.targetoptions["fastmath"] is True
+    assert numba_mul_fn.targetoptions["fastmath"] == {
+        "afn",
+        "arcp",
+        "contract",
+        "nsz",
+        "reassoc",
+    }


 def test_config_options_cached():

tests/link/numba/test_scalar.py (+19)

@@ -9,6 +9,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py, set_test_value

@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
         if not isinstance(i, SharedVariable | Constant)
     ],
 )
+
+
+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+    # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
+    compare_numba_and_py(
+        ([x], [out]),
+        [np.array([1, 0], dtype="float64")],
+    )
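
This new test guards against the classic blanket fast-math hazard: `fastmath=True` enables LLVM's `nnan` flag, which lets the compiler assume NaNs never occur and can fold `isnan` checks to `False`. A rough standalone illustration of the contrast in plain Numba (actual behavior may vary by Numba/LLVM version):

    import numba
    import numpy as np

    @numba.njit(fastmath=True)  # includes "nnan": the compiler may assume no NaNs exist
    def isnan_blanket(x):
        return np.isnan(x)

    @numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
    def isnan_specific(x):
        return np.isnan(x)

    print(isnan_blanket(np.nan), isnan_specific(np.nan))  # possibly: False True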
