Skip to content

Commit 7dc0edb

Browse files
mroeschkefangchenli
authored andcommitted
ENH: Add compute.use_numba configuration for automatically using numba (pandas-dev#35182)
1 parent 906efe5 commit 7dc0edb

File tree

11 files changed

+124
-43
lines changed

11 files changed

+124
-43
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ Other enhancements
338338
- :meth:`read_csv` now accepts string values like "0", "0.0", "1", "1.0" as convertible to the nullable boolean dtype (:issue:`34859`)
339339
- :class:`pandas.core.window.ExponentialMovingWindow` now supports a ``times`` argument that allows ``mean`` to be calculated with observations spaced by the timestamps in ``times`` (:issue:`34839`)
340340
- :meth:`DataFrame.agg` and :meth:`Series.agg` now accept named aggregation for renaming the output columns/indexes. (:issue:`26513`)
341+
- Added the ``compute.use_numba`` configuration option; when it is set to ``True``, operations that accept an ``engine`` argument use the numba engine by default, i.e. when ``engine=None`` (:issue:`33966`)
341342

342343
.. ---------------------------------------------------------------------------
343344

pandas/core/config_init.py

+17
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,20 @@ def use_numexpr_cb(key):
5252
expressions.set_use_numexpr(cf.get_option(key))
5353

5454

55+
# Help text passed to ``cf.register_option`` for the ``use_numba`` option.
# NOTE(review): the 4-space indentation of the description lines is assumed;
# confirm against the sibling ``use_numexpr_doc`` string.
use_numba_doc = """
: bool
    Use the numba engine option for select operations if it is installed,
    the default is False
    Valid values: False,True
"""
61+
62+
63+
def use_numba_cb(key):
    """
    Forward the current value of a config option to the numba utilities.

    Parameters
    ----------
    key : str
        Name of the option whose value is read with ``cf.get_option`` and
        passed on to ``numba_.set_use_numba``.
    """
    # Function-scope import; presumably keeps ``numba_`` from being loaded
    # while the config module itself is imported -- confirm.
    from pandas.core.util import numba_

    numba_.set_use_numba(cf.get_option(key))
67+
68+
5569
with cf.config_prefix("compute"):
5670
cf.register_option(
5771
"use_bottleneck",
@@ -63,6 +77,9 @@ def use_numexpr_cb(key):
6377
cf.register_option(
6478
"use_numexpr", True, use_numexpr_doc, validator=is_bool, cb=use_numexpr_cb
6579
)
80+
cf.register_option(
81+
"use_numba", False, use_numba_doc, validator=is_bool, cb=use_numba_cb
82+
)
6683
#
6784
# options from the "display" namespace
6885

pandas/core/groupby/generic.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from pandas.core.util.numba_ import (
8181
NUMBA_FUNC_CACHE,
8282
generate_numba_func,
83+
maybe_use_numba,
8384
split_for_numba,
8485
)
8586

@@ -227,9 +228,7 @@ def apply(self, func, *args, **kwargs):
227228
@doc(
228229
_agg_template, examples=_agg_examples_doc, klass="Series",
229230
)
230-
def aggregate(
231-
self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
232-
):
231+
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
233232

234233
relabeling = func is None
235234
columns = None
@@ -483,7 +482,7 @@ def _aggregate_named(self, func, *args, **kwargs):
483482

484483
@Substitution(klass="Series")
485484
@Appender(_transform_template)
486-
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
485+
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
487486
func = self._get_cython_func(func) or func
488487

489488
if not isinstance(func, str):
@@ -515,7 +514,7 @@ def _transform_general(
515514
Transform with a non-str `func`.
516515
"""
517516

518-
if engine == "numba":
517+
if maybe_use_numba(engine):
519518
numba_func, cache_key = generate_numba_func(
520519
func, engine_kwargs, kwargs, "groupby_transform"
521520
)
@@ -525,7 +524,7 @@ def _transform_general(
525524
results = []
526525
for name, group in self:
527526
object.__setattr__(group, "name", name)
528-
if engine == "numba":
527+
if maybe_use_numba(engine):
529528
values, index = split_for_numba(group)
530529
res = numba_func(values, index, *args)
531530
if cache_key not in NUMBA_FUNC_CACHE:
@@ -934,13 +933,11 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
934933
@doc(
935934
_agg_template, examples=_agg_examples_doc, klass="DataFrame",
936935
)
937-
def aggregate(
938-
self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
939-
):
936+
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
940937

941938
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
942939

943-
if engine == "numba":
940+
if maybe_use_numba(engine):
944941
return self._python_agg_general(
945942
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
946943
)
@@ -1385,7 +1382,7 @@ def _transform_general(
13851382
applied = []
13861383
obj = self._obj_with_exclusions
13871384
gen = self.grouper.get_iterator(obj, axis=self.axis)
1388-
if engine == "numba":
1385+
if maybe_use_numba(engine):
13891386
numba_func, cache_key = generate_numba_func(
13901387
func, engine_kwargs, kwargs, "groupby_transform"
13911388
)
@@ -1395,7 +1392,7 @@ def _transform_general(
13951392
for name, group in gen:
13961393
object.__setattr__(group, "name", name)
13971394

1398-
if engine == "numba":
1395+
if maybe_use_numba(engine):
13991396
values, index = split_for_numba(group)
14001397
res = numba_func(values, index, *args)
14011398
if cache_key not in NUMBA_FUNC_CACHE:
@@ -1446,7 +1443,7 @@ def _transform_general(
14461443

14471444
@Substitution(klass="DataFrame")
14481445
@Appender(_transform_template)
1449-
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
1446+
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
14501447

14511448
# optimized transforms
14521449
func = self._get_cython_func(func) or func

pandas/core/groupby/groupby.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class providing the base-class of operations.
6565
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
6666
from pandas.core.series import Series
6767
from pandas.core.sorting import get_group_index_sorter
68+
from pandas.core.util.numba_ import maybe_use_numba
6869

6970
_common_see_also = """
7071
See Also
@@ -286,9 +287,10 @@ class providing the base-class of operations.
286287
.. versionchanged:: 1.1.0
287288
*args
288289
Positional arguments to pass to func
289-
engine : str, default 'cython'
290+
engine : str, default None
290291
* ``'cython'`` : Runs the function through C-extensions from cython.
291292
* ``'numba'`` : Runs the function through JIT compiled code from numba.
293+
* ``None`` : Defaults to ``'cython'``, or to ``'numba'`` if the global option ``compute.use_numba`` is set to ``True``
292294
293295
.. versionadded:: 1.1.0
294296
engine_kwargs : dict, default None
@@ -393,9 +395,10 @@ class providing the base-class of operations.
393395
.. versionchanged:: 1.1.0
394396
*args
395397
Positional arguments to pass to func
396-
engine : str, default 'cython'
398+
engine : str, default None
397399
* ``'cython'`` : Runs the function through C-extensions from cython.
398400
* ``'numba'`` : Runs the function through JIT compiled code from numba.
401+
* ``None`` : Defaults to ``'cython'``, or to ``'numba'`` if the global option ``compute.use_numba`` is set to ``True``
399402
400403
.. versionadded:: 1.1.0
401404
engine_kwargs : dict, default None
@@ -1063,7 +1066,7 @@ def _python_agg_general(
10631066
# agg_series below assumes ngroups > 0
10641067
continue
10651068

1066-
if engine == "numba":
1069+
if maybe_use_numba(engine):
10671070
result, counts = self.grouper.agg_series(
10681071
obj,
10691072
func,

pandas/core/groupby/ops.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from pandas.core.util.numba_ import (
5959
NUMBA_FUNC_CACHE,
6060
generate_numba_func,
61+
maybe_use_numba,
6162
split_for_numba,
6263
)
6364

@@ -620,7 +621,7 @@ def agg_series(
620621
# Caller is responsible for checking ngroups != 0
621622
assert self.ngroups != 0
622623

623-
if engine == "numba":
624+
if maybe_use_numba(engine):
624625
return self._aggregate_series_pure_python(
625626
obj, func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
626627
)
@@ -678,7 +679,7 @@ def _aggregate_series_pure_python(
678679
**kwargs,
679680
):
680681

681-
if engine == "numba":
682+
if maybe_use_numba(engine):
682683
numba_func, cache_key = generate_numba_func(
683684
func, engine_kwargs, kwargs, "groupby_agg"
684685
)
@@ -691,7 +692,7 @@ def _aggregate_series_pure_python(
691692
splitter = get_splitter(obj, group_index, ngroups, axis=0)
692693

693694
for label, group in splitter:
694-
if engine == "numba":
695+
if maybe_use_numba(engine):
695696
values, index = split_for_numba(group)
696697
res = numba_func(values, index, *args)
697698
if cache_key not in NUMBA_FUNC_CACHE:

pandas/core/util/numba_.py

+13
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,22 @@
1010
from pandas.compat._optional import import_optional_dependency
1111
from pandas.errors import NumbaUtilError
1212

13+
# Module-level switch written by ``set_use_numba`` and read by
# ``maybe_use_numba`` when a caller passes ``engine=None``.
GLOBAL_USE_NUMBA: bool = False
1314
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict()
1415

1516

17+
def maybe_use_numba(engine: Optional[str]) -> bool:
    """
    Signal whether to use numba routines.

    Parameters
    ----------
    engine : str or None
        Engine requested by the caller.

    Returns
    -------
    bool
        True if ``engine`` is ``"numba"``, or if ``engine`` is None and the
        module-level ``GLOBAL_USE_NUMBA`` flag is set; False otherwise.
    """
    # An explicit engine always wins; the global flag is consulted only
    # when the caller left the engine unspecified.
    return engine == "numba" or (engine is None and GLOBAL_USE_NUMBA)
20+
21+
22+
def set_use_numba(enable: bool = False) -> None:
    """
    Set the module-level ``GLOBAL_USE_NUMBA`` flag.

    Parameters
    ----------
    enable : bool, default False
        New value for the flag. When truthy, ``numba`` is first looked up
        via ``import_optional_dependency``; any error raised by that lookup
        propagates and the flag is left unchanged.
    """
    global GLOBAL_USE_NUMBA
    if enable:
        # Check for numba before switching the flag on, so a failed lookup
        # cannot leave the flag enabled.
        import_optional_dependency("numba")
    GLOBAL_USE_NUMBA = enable
27+
28+
1629
def check_kwargs_and_nopython(
1730
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
1831
) -> None:

pandas/core/window/rolling.py

+16-21
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import pandas.core.common as com
4040
from pandas.core.construction import extract_array
4141
from pandas.core.indexes.api import Index, MultiIndex, ensure_index
42-
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
42+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba
4343
from pandas.core.window.common import (
4444
WindowGroupByMixin,
4545
_doc_template,
@@ -1298,10 +1298,11 @@ def count(self):
12981298
objects instead.
12991299
If you are just applying a NumPy reduction function this will
13001300
achieve much better performance.
1301-
engine : str, default 'cython'
1301+
engine : str, default None
13021302
* ``'cython'`` : Runs rolling apply through C-extensions from cython.
13031303
* ``'numba'`` : Runs rolling apply through JIT compiled code from numba.
13041304
Only available when ``raw`` is set to ``True``.
1305+
* ``None`` : Defaults to ``'cython'``, or to ``'numba'`` if the global option ``compute.use_numba`` is set to ``True``
13051306
13061307
.. versionadded:: 1.0.0
13071308
@@ -1357,18 +1358,7 @@ def apply(
13571358
if not is_bool(raw):
13581359
raise ValueError("raw parameter must be `True` or `False`")
13591360

1360-
if engine == "cython":
1361-
if engine_kwargs is not None:
1362-
raise ValueError("cython engine does not accept engine_kwargs")
1363-
# Cython apply functions handle center, so don't need to use
1364-
# _apply's center handling
1365-
window = self._get_window()
1366-
offset = calculate_center_offset(window) if self.center else 0
1367-
apply_func = self._generate_cython_apply_func(
1368-
args, kwargs, raw, offset, func
1369-
)
1370-
center = False
1371-
elif engine == "numba":
1361+
if maybe_use_numba(engine):
13721362
if raw is False:
13731363
raise ValueError("raw must be `True` when using the numba engine")
13741364
cache_key = (func, "rolling_apply")
@@ -1380,6 +1370,17 @@ def apply(
13801370
args, kwargs, func, engine_kwargs
13811371
)
13821372
center = self.center
1373+
elif engine in ("cython", None):
1374+
if engine_kwargs is not None:
1375+
raise ValueError("cython engine does not accept engine_kwargs")
1376+
# Cython apply functions handle center, so don't need to use
1377+
# _apply's center handling
1378+
window = self._get_window()
1379+
offset = calculate_center_offset(window) if self.center else 0
1380+
apply_func = self._generate_cython_apply_func(
1381+
args, kwargs, raw, offset, func
1382+
)
1383+
center = False
13831384
else:
13841385
raise ValueError("engine must be either 'numba' or 'cython'")
13851386

@@ -2053,13 +2054,7 @@ def count(self):
20532054
@Substitution(name="rolling")
20542055
@Appender(_shared_docs["apply"])
20552056
def apply(
2056-
self,
2057-
func,
2058-
raw=False,
2059-
engine="cython",
2060-
engine_kwargs=None,
2061-
args=None,
2062-
kwargs=None,
2057+
self, func, raw=False, engine=None, engine_kwargs=None, args=None, kwargs=None,
20632058
):
20642059
return super().apply(
20652060
func,

pandas/tests/groupby/aggregate/test_numba.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pandas.errors import NumbaUtilError
55
import pandas.util._test_decorators as td
66

7-
from pandas import DataFrame
7+
from pandas import DataFrame, option_context
88
import pandas._testing as tm
99
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
1010

@@ -113,3 +113,18 @@ def func_2(values, index):
113113
result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
114114
expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
115115
tm.assert_equal(result, expected)
116+
117+
118+
@td.skip_if_no("numba", "0.46.0")
def test_use_global_config():
    """Check groupby ``agg`` with ``engine=None`` under ``compute.use_numba``.

    The result obtained with the option enabled must equal the result of an
    explicit ``engine="numba"`` call.
    """

    def func_1(values, index):
        return np.mean(values) - 3.4

    data = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1],
    )
    grouped = data.groupby(0)
    # Reference value: numba requested explicitly, option left at its default.
    expected = grouped.agg(func_1, engine="numba")
    with option_context("compute.use_numba", True):
        result = grouped.agg(func_1, engine=None)
    tm.assert_frame_equal(result, expected)

pandas/tests/groupby/transform/test_numba.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pandas.errors import NumbaUtilError
44
import pandas.util._test_decorators as td
55

6-
from pandas import DataFrame
6+
from pandas import DataFrame, option_context
77
import pandas._testing as tm
88
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
99

@@ -112,3 +112,18 @@ def func_2(values, index):
112112
result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs)
113113
expected = grouped.transform(lambda x: x + 1, engine="cython")
114114
tm.assert_equal(result, expected)
115+
116+
117+
@td.skip_if_no("numba", "0.46.0")
def test_use_global_config():
    """Check groupby ``transform`` with ``engine=None`` under ``compute.use_numba``.

    The result obtained with the option enabled must equal the result of an
    explicit ``engine="numba"`` call.
    """

    def func_1(values, index):
        return values + 1

    data = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1],
    )
    grouped = data.groupby(0)
    # Reference value: numba requested explicitly, option left at its default.
    expected = grouped.transform(func_1, engine="numba")
    with option_context("compute.use_numba", True):
        result = grouped.transform(func_1, engine=None)
    tm.assert_frame_equal(result, expected)

pandas/tests/util/test_numba.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pytest
2+
3+
import pandas.util._test_decorators as td
4+
5+
from pandas import option_context
6+
7+
8+
@td.skip_if_installed("numba")
def test_numba_not_installed_option_context():
    """Enabling ``compute.use_numba`` without numba installed must raise."""
    # The context managers stay nested so that the ImportError is attributed
    # to entering ``option_context``, not to anything run inside it.
    with pytest.raises(ImportError, match="Missing optional"):
        with option_context("compute.use_numba", True):
            pass

pandas/tests/window/test_numba.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas.util._test_decorators as td
55

6-
from pandas import Series
6+
from pandas import Series, option_context
77
import pandas._testing as tm
88
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
99

@@ -75,3 +75,15 @@ def func_2(x):
7575
)
7676
expected = roll.apply(func_1, engine="cython", raw=True)
7777
tm.assert_series_equal(result, expected)
78+
79+
80+
@td.skip_if_no("numba", "0.46.0")
def test_use_global_config():
    """Check rolling ``apply`` with ``engine=None`` under ``compute.use_numba``.

    The result obtained with the option enabled must equal the result of an
    explicit ``engine="numba"`` call.
    """

    def f(x):
        return np.mean(x) + 2

    s = Series(range(10))
    with option_context("compute.use_numba", True):
        result = s.rolling(2).apply(f, engine=None, raw=True)
    expected = s.rolling(2).apply(f, engine="numba", raw=True)
    tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)