Skip to content

Commit b15ef37

Browse files
authored
ENH: Add numba engine to groupby.sum (#44939)
1 parent 029ae81 commit b15ef37

File tree

4 files changed

+36
-16
lines changed

4 files changed

+36
-16
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ Other enhancements
219219
- :meth:`DataFrame.dropna` now accepts a single label as ``subset`` along with array-like (:issue:`41021`)
220220
- :class:`ExcelWriter` argument ``if_sheet_exists="overlay"`` option added (:issue:`40231`)
221221
- :meth:`read_excel` now accepts a ``decimal`` argument that allow the user to specify the decimal point when parsing string columns to numeric (:issue:`14403`)
222-
- :meth:`.GroupBy.mean`, :meth:`.GroupBy.std`, and :meth:`.GroupBy.var` now supports `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`43731`, :issue:`44862`)
222+
- :meth:`.GroupBy.mean`, :meth:`.GroupBy.std`, :meth:`.GroupBy.var`, :meth:`.GroupBy.sum` now supports `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`43731`, :issue:`44862`, :issue:`44939`)
223223
- :meth:`Timestamp.isoformat`, now handles the ``timespec`` argument from the base :class:``datetime`` class (:issue:`26131`)
224224
- :meth:`NaT.to_numpy` ``dtype`` argument is now respected, so ``np.timedelta64`` can be returned (:issue:`44460`)
225225
- New option ``display.max_dir_items`` customizes the number of columns added to :meth:`Dataframe.__dir__` and suggested for tab completion (:issue:`37996`)

pandas/core/groupby/groupby.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -2163,22 +2163,35 @@ def size(self) -> DataFrame | Series:
21632163
@final
21642164
@doc(_groupby_agg_method_template, fname="sum", no=True, mc=0)
21652165
def sum(
2166-
self, numeric_only: bool | lib.NoDefault = lib.no_default, min_count: int = 0
2166+
self,
2167+
numeric_only: bool | lib.NoDefault = lib.no_default,
2168+
min_count: int = 0,
2169+
engine: str | None = None,
2170+
engine_kwargs: dict[str, bool] | None = None,
21672171
):
2168-
numeric_only = self._resolve_numeric_only(numeric_only)
2172+
if maybe_use_numba(engine):
2173+
from pandas.core._numba.kernels import sliding_sum
21692174

2170-
# If we are grouping on categoricals we want unobserved categories to
2171-
# return zero, rather than the default of NaN which the reindexing in
2172-
# _agg_general() returns. GH #31422
2173-
with com.temp_setattr(self, "observed", True):
2174-
result = self._agg_general(
2175-
numeric_only=numeric_only,
2176-
min_count=min_count,
2177-
alias="add",
2178-
npfunc=np.sum,
2175+
return self._numba_agg_general(
2176+
sliding_sum,
2177+
engine_kwargs,
2178+
"groupby_sum",
21792179
)
2180+
else:
2181+
numeric_only = self._resolve_numeric_only(numeric_only)
21802182

2181-
return self._reindex_output(result, fill_value=0)
2183+
# If we are grouping on categoricals we want unobserved categories to
2184+
# return zero, rather than the default of NaN which the reindexing in
2185+
# _agg_general() returns. GH #31422
2186+
with com.temp_setattr(self, "observed", True):
2187+
result = self._agg_general(
2188+
numeric_only=numeric_only,
2189+
min_count=min_count,
2190+
alias="add",
2191+
npfunc=np.sum,
2192+
)
2193+
2194+
return self._reindex_output(result, fill_value=0)
21822195

21832196
@final
21842197
@doc(_groupby_agg_method_template, fname="prod", no=True, mc=0)

pandas/tests/groupby/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def nopython(request):
183183
("var", {"ddof": 0}),
184184
("std", {"ddof": 1}),
185185
("std", {"ddof": 0}),
186+
("sum", {}),
186187
]
187188
)
188189
def numba_supported_reductions(request):

pandas/tests/groupby/test_numba.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def test_cython_vs_numba_frame(
2424
engine="numba", engine_kwargs=engine_kwargs, **kwargs
2525
)
2626
expected = getattr(gb, func)(**kwargs)
27-
tm.assert_frame_equal(result, expected)
27+
# check_dtype can be removed if GH 44952 is addressed
28+
check_dtype = func != "sum"
29+
tm.assert_frame_equal(result, expected, check_dtype=check_dtype)
2830

2931
def test_cython_vs_numba_getitem(
3032
self, sort, nogil, parallel, nopython, numba_supported_reductions
@@ -37,7 +39,9 @@ def test_cython_vs_numba_getitem(
3739
engine="numba", engine_kwargs=engine_kwargs, **kwargs
3840
)
3941
expected = getattr(gb, func)(**kwargs)
40-
tm.assert_series_equal(result, expected)
42+
# check_dtype can be removed if GH 44952 is addressed
43+
check_dtype = func != "sum"
44+
tm.assert_series_equal(result, expected, check_dtype=check_dtype)
4145

4246
def test_cython_vs_numba_series(
4347
self, sort, nogil, parallel, nopython, numba_supported_reductions
@@ -50,7 +54,9 @@ def test_cython_vs_numba_series(
5054
engine="numba", engine_kwargs=engine_kwargs, **kwargs
5155
)
5256
expected = getattr(gb, func)(**kwargs)
53-
tm.assert_series_equal(result, expected)
57+
# check_dtype can be removed if GH 44952 is addressed
58+
check_dtype = func != "sum"
59+
tm.assert_series_equal(result, expected, check_dtype=check_dtype)
5460

5561
def test_as_index_false_unsupported(self, numba_supported_reductions):
5662
func, kwargs = numba_supported_reductions

0 commit comments

Comments
 (0)