Skip to content

Commit 73b4167

Browse files
authored
ENH: Add numba engine to groupby.var/std (#44862)
1 parent 282481f commit 73b4167

File tree

4 files changed

+125
-34
lines changed

4 files changed

+125
-34
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ Other enhancements
219219
- :meth:`DataFrame.dropna` now accepts a single label as ``subset`` along with array-like (:issue:`41021`)
220220
- :class:`ExcelWriter` argument ``if_sheet_exists="overlay"`` option added (:issue:`40231`)
221221
- :meth:`read_excel` now accepts a ``decimal`` argument that allow the user to specify the decimal point when parsing string columns to numeric (:issue:`14403`)
222-
- :meth:`.GroupBy.mean` now supports `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`43731`)
222+
- :meth:`.GroupBy.mean`, :meth:`.GroupBy.std`, and :meth:`.GroupBy.var` now supports `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`43731`, :issue:`44862`)
223223
- :meth:`Timestamp.isoformat`, now handles the ``timespec`` argument from the base :class:``datetime`` class (:issue:`26131`)
224224
- :meth:`NaT.to_numpy` ``dtype`` argument is now respected, so ``np.timedelta64`` can be returned (:issue:`44460`)
225225
- New option ``display.max_dir_items`` customizes the number of columns added to :meth:`Dataframe.__dir__` and suggested for tab completion (:issue:`37996`)

pandas/core/groupby/groupby.py

+78-17
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,7 @@ def _numba_agg_general(
12721272
func: Callable,
12731273
engine_kwargs: dict[str, bool] | None,
12741274
numba_cache_key_str: str,
1275+
*aggregator_args,
12751276
):
12761277
"""
12771278
Perform groupby with a standard numerical aggregation function (e.g. mean)
@@ -1291,7 +1292,7 @@ def _numba_agg_general(
12911292
aggregator = executor.generate_shared_aggregator(
12921293
func, engine_kwargs, numba_cache_key_str
12931294
)
1294-
result = aggregator(sorted_data, starts, ends, 0)
1295+
result = aggregator(sorted_data, starts, ends, 0, *aggregator_args)
12951296

12961297
cache_key = (func, numba_cache_key_str)
12971298
if cache_key not in NUMBA_FUNC_CACHE:
@@ -1989,7 +1990,12 @@ def median(self, numeric_only: bool | lib.NoDefault = lib.no_default):
19891990
@final
19901991
@Substitution(name="groupby")
19911992
@Appender(_common_see_also)
1992-
def std(self, ddof: int = 1):
1993+
def std(
1994+
self,
1995+
ddof: int = 1,
1996+
engine: str | None = None,
1997+
engine_kwargs: dict[str, bool] | None = None,
1998+
):
19931999
"""
19942000
Compute standard deviation of groups, excluding missing values.
19952001
@@ -2000,23 +2006,52 @@ def std(self, ddof: int = 1):
20002006
ddof : int, default 1
20012007
Degrees of freedom.
20022008
2009+
engine : str, default None
2010+
* ``'cython'`` : Runs the operation through C-extensions from cython.
2011+
* ``'numba'`` : Runs the operation through JIT compiled code from numba.
2012+
* ``None`` : Defaults to ``'cython'`` or globally setting
2013+
``compute.use_numba``
2014+
2015+
.. versionadded:: 1.4.0
2016+
2017+
engine_kwargs : dict, default None
2018+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
2019+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
2020+
and ``parallel`` dictionary keys. The values must either be ``True`` or
2021+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
2022+
``{{'nopython': True, 'nogil': False, 'parallel': False}}``
2023+
2024+
.. versionadded:: 1.4.0
2025+
20032026
Returns
20042027
-------
20052028
Series or DataFrame
20062029
Standard deviation of values within each group.
20072030
"""
2008-
return self._get_cythonized_result(
2009-
libgroupby.group_var,
2010-
needs_counts=True,
2011-
cython_dtype=np.dtype(np.float64),
2012-
post_processing=lambda vals, inference: np.sqrt(vals),
2013-
ddof=ddof,
2014-
)
2031+
if maybe_use_numba(engine):
2032+
from pandas.core._numba.kernels import sliding_var
2033+
2034+
return np.sqrt(
2035+
self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof)
2036+
)
2037+
else:
2038+
return self._get_cythonized_result(
2039+
libgroupby.group_var,
2040+
needs_counts=True,
2041+
cython_dtype=np.dtype(np.float64),
2042+
post_processing=lambda vals, inference: np.sqrt(vals),
2043+
ddof=ddof,
2044+
)
20152045

20162046
@final
20172047
@Substitution(name="groupby")
20182048
@Appender(_common_see_also)
2019-
def var(self, ddof: int = 1):
2049+
def var(
2050+
self,
2051+
ddof: int = 1,
2052+
engine: str | None = None,
2053+
engine_kwargs: dict[str, bool] | None = None,
2054+
):
20202055
"""
20212056
Compute variance of groups, excluding missing values.
20222057
@@ -2027,20 +2062,46 @@ def var(self, ddof: int = 1):
20272062
ddof : int, default 1
20282063
Degrees of freedom.
20292064
2065+
engine : str, default None
2066+
* ``'cython'`` : Runs the operation through C-extensions from cython.
2067+
* ``'numba'`` : Runs the operation through JIT compiled code from numba.
2068+
* ``None`` : Defaults to ``'cython'`` or globally setting
2069+
``compute.use_numba``
2070+
2071+
.. versionadded:: 1.4.0
2072+
2073+
engine_kwargs : dict, default None
2074+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
2075+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
2076+
and ``parallel`` dictionary keys. The values must either be ``True`` or
2077+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
2078+
``{{'nopython': True, 'nogil': False, 'parallel': False}}``
2079+
2080+
.. versionadded:: 1.4.0
2081+
20302082
Returns
20312083
-------
20322084
Series or DataFrame
20332085
Variance of values within each group.
20342086
"""
2035-
if ddof == 1:
2036-
numeric_only = self._resolve_numeric_only(lib.no_default)
2037-
return self._cython_agg_general(
2038-
"var", alt=lambda x: Series(x).var(ddof=ddof), numeric_only=numeric_only
2087+
if maybe_use_numba(engine):
2088+
from pandas.core._numba.kernels import sliding_var
2089+
2090+
return self._numba_agg_general(
2091+
sliding_var, engine_kwargs, "groupby_var", ddof
20392092
)
20402093
else:
2041-
func = lambda x: x.var(ddof=ddof)
2042-
with self._group_selection_context():
2043-
return self._python_agg_general(func)
2094+
if ddof == 1:
2095+
numeric_only = self._resolve_numeric_only(lib.no_default)
2096+
return self._cython_agg_general(
2097+
"var",
2098+
alt=lambda x: Series(x).var(ddof=ddof),
2099+
numeric_only=numeric_only,
2100+
)
2101+
else:
2102+
func = lambda x: x.var(ddof=ddof)
2103+
with self._group_selection_context():
2104+
return self._python_agg_general(func)
20442105

20452106
@final
20462107
@Substitution(name="groupby")

pandas/tests/groupby/conftest.py

+14
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,17 @@ def nogil(request):
174174
def nopython(request):
175175
"""nopython keyword argument for numba.jit"""
176176
return request.param
177+
178+
179+
@pytest.fixture(
180+
params=[
181+
("mean", {}),
182+
("var", {"ddof": 1}),
183+
("var", {"ddof": 0}),
184+
("std", {"ddof": 1}),
185+
("std", {"ddof": 0}),
186+
]
187+
)
188+
def numba_supported_reductions(request):
189+
"""reductions supported with engine='numba'"""
190+
return request.param

pandas/tests/groupby/test_numba.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,55 @@
1313
@pytest.mark.filterwarnings("ignore:\n")
1414
# Filter warnings when parallel=True and the function can't be parallelized by Numba
1515
class TestEngine:
16-
def test_cython_vs_numba_frame(self, sort, nogil, parallel, nopython):
16+
def test_cython_vs_numba_frame(
17+
self, sort, nogil, parallel, nopython, numba_supported_reductions
18+
):
19+
func, kwargs = numba_supported_reductions
1720
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
1821
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
19-
result = df.groupby("a", sort=sort).mean(
20-
engine="numba", engine_kwargs=engine_kwargs
22+
gb = df.groupby("a", sort=sort)
23+
result = getattr(gb, func)(
24+
engine="numba", engine_kwargs=engine_kwargs, **kwargs
2125
)
22-
expected = df.groupby("a", sort=sort).mean()
26+
expected = getattr(gb, func)(**kwargs)
2327
tm.assert_frame_equal(result, expected)
2428

25-
def test_cython_vs_numba_getitem(self, sort, nogil, parallel, nopython):
29+
def test_cython_vs_numba_getitem(
30+
self, sort, nogil, parallel, nopython, numba_supported_reductions
31+
):
32+
func, kwargs = numba_supported_reductions
2633
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
2734
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
28-
result = df.groupby("a", sort=sort)["c"].mean(
29-
engine="numba", engine_kwargs=engine_kwargs
35+
gb = df.groupby("a", sort=sort)["c"]
36+
result = getattr(gb, func)(
37+
engine="numba", engine_kwargs=engine_kwargs, **kwargs
3038
)
31-
expected = df.groupby("a", sort=sort)["c"].mean()
39+
expected = getattr(gb, func)(**kwargs)
3240
tm.assert_series_equal(result, expected)
3341

34-
def test_cython_vs_numba_series(self, sort, nogil, parallel, nopython):
42+
def test_cython_vs_numba_series(
43+
self, sort, nogil, parallel, nopython, numba_supported_reductions
44+
):
45+
func, kwargs = numba_supported_reductions
3546
ser = Series(range(3), index=[1, 2, 1], name="foo")
3647
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
37-
result = ser.groupby(level=0, sort=sort).mean(
38-
engine="numba", engine_kwargs=engine_kwargs
48+
gb = ser.groupby(level=0, sort=sort)
49+
result = getattr(gb, func)(
50+
engine="numba", engine_kwargs=engine_kwargs, **kwargs
3951
)
40-
expected = ser.groupby(level=0, sort=sort).mean()
52+
expected = getattr(gb, func)(**kwargs)
4153
tm.assert_series_equal(result, expected)
4254

43-
def test_as_index_false_unsupported(self):
55+
def test_as_index_false_unsupported(self, numba_supported_reductions):
56+
func, kwargs = numba_supported_reductions
4457
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
58+
gb = df.groupby("a", as_index=False)
4559
with pytest.raises(NotImplementedError, match="as_index=False"):
46-
df.groupby("a", as_index=False).mean(engine="numba")
60+
getattr(gb, func)(engine="numba", **kwargs)
4761

48-
def test_axis_1_unsupported(self):
62+
def test_axis_1_unsupported(self, numba_supported_reductions):
63+
func, kwargs = numba_supported_reductions
4964
df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
65+
gb = df.groupby("a", axis=1)
5066
with pytest.raises(NotImplementedError, match="axis=1"):
51-
df.groupby("a", axis=1).mean(engine="numba")
67+
getattr(gb, func)(engine="numba", **kwargs)

0 commit comments

Comments
 (0)