Skip to content

Commit e87ad05

Browse files
authored
ENH: Add engine="numba" to groupby mean (#43731)
1 parent 9e7e7a2 commit e87ad05

File tree

4 files changed

+130
-8
lines changed

4 files changed

+130
-8
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ Other enhancements
182182
- Added :meth:`.ExponentialMovingWindow.sum` (:issue:`13297`)
183183
- :meth:`Series.str.split` now supports a ``regex`` argument that explicitly specifies whether the pattern is a regular expression. Default is ``None`` (:issue:`43563`, :issue:`32835`, :issue:`25549`)
184184
- :meth:`DataFrame.dropna` now accepts a single label as ``subset`` along with array-like (:issue:`41021`)
185-
-
185+
- :meth:`.GroupBy.mean` now supports `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`43731`)
186186

187187
.. ---------------------------------------------------------------------------
188188

pandas/core/groupby/groupby.py

+73-7
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class providing the base-class of operations.
7979
)
8080

8181
from pandas.core import nanops
82+
from pandas.core._numba import executor
8283
import pandas.core.algorithms as algorithms
8384
from pandas.core.arrays import (
8485
BaseMaskedArray,
@@ -1259,6 +1260,44 @@ def _numba_prep(self, func, data):
12591260
sorted_data,
12601261
)
12611262

1263+
def _numba_agg_general(
1264+
self,
1265+
func: Callable,
1266+
engine_kwargs: dict[str, bool] | None,
1267+
numba_cache_key_str: str,
1268+
):
1269+
"""
1270+
Perform groupby with a standard numerical aggregation function (e.g. mean)
1271+
with Numba.
1272+
"""
1273+
if not self.as_index:
1274+
raise NotImplementedError(
1275+
"as_index=False is not supported. Use .reset_index() instead."
1276+
)
1277+
if self.axis == 1:
1278+
raise NotImplementedError("axis=1 is not supported.")
1279+
1280+
with self._group_selection_context():
1281+
data = self._selected_obj
1282+
df = data if data.ndim == 2 else data.to_frame()
1283+
starts, ends, sorted_index, sorted_data = self._numba_prep(func, df)
1284+
aggregator = executor.generate_shared_aggregator(
1285+
func, engine_kwargs, numba_cache_key_str
1286+
)
1287+
result = aggregator(sorted_data, starts, ends, 0)
1288+
1289+
cache_key = (func, numba_cache_key_str)
1290+
if cache_key not in NUMBA_FUNC_CACHE:
1291+
NUMBA_FUNC_CACHE[cache_key] = aggregator
1292+
1293+
index = self.grouper.result_index
1294+
if data.ndim == 1:
1295+
result_kwargs = {"name": data.name}
1296+
result = result.ravel()
1297+
else:
1298+
result_kwargs = {"columns": data.columns}
1299+
return data._constructor(result, index=index, **result_kwargs)
1300+
12621301
@final
12631302
def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
12641303
"""
@@ -1827,7 +1866,12 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
18271866
@final
18281867
@Substitution(name="groupby")
18291868
@Substitution(see_also=_common_see_also)
1830-
def mean(self, numeric_only: bool | lib.NoDefault = lib.no_default):
1869+
def mean(
1870+
self,
1871+
numeric_only: bool | lib.NoDefault = lib.no_default,
1872+
engine: str = "cython",
1873+
engine_kwargs: dict[str, bool] | None = None,
1874+
):
18311875
"""
18321876
Compute mean of groups, excluding missing values.
18331877
@@ -1837,6 +1881,23 @@ def mean(self, numeric_only: bool | lib.NoDefault = lib.no_default):
18371881
Include only float, int, boolean columns. If None, will attempt to use
18381882
everything, then use only numeric data.
18391883
1884+
engine : str, default None
1885+
* ``'cython'`` : Runs the operation through C-extensions from cython.
1886+
* ``'numba'`` : Runs the operation through JIT compiled code from numba.
1887+
* ``None`` : Defaults to ``'cython'`` or globally setting
1888+
``compute.use_numba``
1889+
1890+
.. versionadded:: 1.4.0
1891+
1892+
engine_kwargs : dict, default None
1893+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
1894+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
1895+
and ``parallel`` dictionary keys. The values must either be ``True`` or
1896+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
1897+
``{{'nopython': True, 'nogil': False, 'parallel': False}}``
1898+
1899+
.. versionadded:: 1.4.0
1900+
18401901
Returns
18411902
-------
18421903
pandas.Series or pandas.DataFrame
@@ -1877,12 +1938,17 @@ def mean(self, numeric_only: bool | lib.NoDefault = lib.no_default):
18771938
"""
18781939
numeric_only = self._resolve_numeric_only(numeric_only)
18791940

1880-
result = self._cython_agg_general(
1881-
"mean",
1882-
alt=lambda x: Series(x).mean(numeric_only=numeric_only),
1883-
numeric_only=numeric_only,
1884-
)
1885-
return result.__finalize__(self.obj, method="groupby")
1941+
if maybe_use_numba(engine):
1942+
from pandas.core._numba.kernels import sliding_mean
1943+
1944+
return self._numba_agg_general(sliding_mean, engine_kwargs, "groupby_mean")
1945+
else:
1946+
result = self._cython_agg_general(
1947+
"mean",
1948+
alt=lambda x: Series(x).mean(numeric_only=numeric_only),
1949+
numeric_only=numeric_only,
1950+
)
1951+
return result.__finalize__(self.obj, method="groupby")
18861952

18871953
@final
18881954
@Substitution(name="groupby")

pandas/tests/groupby/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
)
1313

1414

15+
@pytest.fixture(params=[True, False])
16+
def sort(request):
17+
return request.param
18+
19+
1520
@pytest.fixture(params=[True, False])
1621
def as_index(request):
1722
return request.param

pandas/tests/groupby/test_numba.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
3+
import pandas.util._test_decorators as td
4+
5+
from pandas import (
6+
DataFrame,
7+
Series,
8+
)
9+
import pandas._testing as tm
10+
11+
12+
@td.skip_if_no("numba")
13+
@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
14+
# Filter warnings when parallel=True and the function can't be parallelized by Numba
15+
class TestEngine:
16+
def test_cython_vs_numba_frame(self, sort, nogil, parallel, nopython):
17+
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
18+
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
19+
result = df.groupby("a", sort=sort).mean(
20+
engine="numba", engine_kwargs=engine_kwargs
21+
)
22+
expected = df.groupby("a", sort=sort).mean()
23+
tm.assert_frame_equal(result, expected)
24+
25+
def test_cython_vs_numba_getitem(self, sort, nogil, parallel, nopython):
26+
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
27+
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
28+
result = df.groupby("a", sort=sort)["c"].mean(
29+
engine="numba", engine_kwargs=engine_kwargs
30+
)
31+
expected = df.groupby("a", sort=sort)["c"].mean()
32+
tm.assert_series_equal(result, expected)
33+
34+
def test_cython_vs_numba_series(self, sort, nogil, parallel, nopython):
35+
ser = Series(range(3), index=[1, 2, 1], name="foo")
36+
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
37+
result = ser.groupby(level=0, sort=sort).mean(
38+
engine="numba", engine_kwargs=engine_kwargs
39+
)
40+
expected = ser.groupby(level=0, sort=sort).mean()
41+
tm.assert_series_equal(result, expected)
42+
43+
def test_as_index_false_unsupported(self):
44+
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
45+
with pytest.raises(NotImplementedError, match="as_index=False"):
46+
df.groupby("a", as_index=False).mean(engine="numba")
47+
48+
def test_axis_1_unsupported(self):
49+
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
50+
with pytest.raises(NotImplementedError, match="axis=1"):
51+
df.groupby("a", axis=1).mean(engine="numba")

0 commit comments

Comments
 (0)