Skip to content

Commit e8ff917

Browse files
mroeschkeyehoshuadimarsky
authored andcommitted
ENH: Add numba engine to groupby.min/max (pandas-dev#45428)
1 parent b09d629 commit e8ff917

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Other enhancements
3535
- :class:`StringArray` now accepts array-likes containing nan-likes (``None``, ``np.nan``) for the ``values`` parameter in its constructor in addition to strings and :attr:`pandas.NA`. (:issue:`40839`)
3636
- Improved the rendering of ``categories`` in :class:`CategoricalIndex` (:issue:`45218`)
3737
- :meth:`to_numeric` now preserves float64 arrays when downcasting would generate values not representable in float32 (:issue:`43693`)
38+
- :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now supports `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)
3839
-
3940

4041
.. ---------------------------------------------------------------------------

pandas/core/groupby/groupby.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -2208,17 +2208,49 @@ def prod(
22082208

22092209
@final
22102210
@doc(_groupby_agg_method_template, fname="min", no=False, mc=-1)
2211-
def min(self, numeric_only: bool = False, min_count: int = -1):
2212-
return self._agg_general(
2213-
numeric_only=numeric_only, min_count=min_count, alias="min", npfunc=np.min
2214-
)
2211+
def min(
2212+
self,
2213+
numeric_only: bool = False,
2214+
min_count: int = -1,
2215+
engine: str | None = None,
2216+
engine_kwargs: dict[str, bool] | None = None,
2217+
):
2218+
if maybe_use_numba(engine):
2219+
from pandas.core._numba.kernels import sliding_min_max
2220+
2221+
return self._numba_agg_general(
2222+
sliding_min_max, engine_kwargs, "groupby_min", False
2223+
)
2224+
else:
2225+
return self._agg_general(
2226+
numeric_only=numeric_only,
2227+
min_count=min_count,
2228+
alias="min",
2229+
npfunc=np.min,
2230+
)
22152231

22162232
@final
22172233
@doc(_groupby_agg_method_template, fname="max", no=False, mc=-1)
2218-
def max(self, numeric_only: bool = False, min_count: int = -1):
2219-
return self._agg_general(
2220-
numeric_only=numeric_only, min_count=min_count, alias="max", npfunc=np.max
2221-
)
2234+
def max(
2235+
self,
2236+
numeric_only: bool = False,
2237+
min_count: int = -1,
2238+
engine: str | None = None,
2239+
engine_kwargs: dict[str, bool] | None = None,
2240+
):
2241+
if maybe_use_numba(engine):
2242+
from pandas.core._numba.kernels import sliding_min_max
2243+
2244+
return self._numba_agg_general(
2245+
sliding_min_max, engine_kwargs, "groupby_max", True
2246+
)
2247+
else:
2248+
return self._agg_general(
2249+
numeric_only=numeric_only,
2250+
min_count=min_count,
2251+
alias="max",
2252+
npfunc=np.max,
2253+
)
22222254

22232255
@final
22242256
@doc(_groupby_agg_method_template, fname="first", no=False, mc=-1)

pandas/tests/groupby/conftest.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,10 @@ def nopython(request):
184184
("std", {"ddof": 1}),
185185
("std", {"ddof": 0}),
186186
("sum", {}),
187-
]
187+
("min", {}),
188+
("max", {}),
189+
],
190+
ids=["mean", "var_1", "var_0", "std_1", "std_0", "sum", "min", "max"],
188191
)
189192
def numba_supported_reductions(request):
190193
"""reductions supported with engine='numba'"""

pandas/tests/groupby/test_numba.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_cython_vs_numba_frame(
2525
)
2626
expected = getattr(gb, func)(**kwargs)
2727
# check_dtype can be removed if GH 44952 is addressed
28-
check_dtype = func != "sum"
28+
check_dtype = func not in ("sum", "min", "max")
2929
tm.assert_frame_equal(result, expected, check_dtype=check_dtype)
3030

3131
def test_cython_vs_numba_getitem(
@@ -40,7 +40,7 @@ def test_cython_vs_numba_getitem(
4040
)
4141
expected = getattr(gb, func)(**kwargs)
4242
# check_dtype can be removed if GH 44952 is addressed
43-
check_dtype = func != "sum"
43+
check_dtype = func not in ("sum", "min", "max")
4444
tm.assert_series_equal(result, expected, check_dtype=check_dtype)
4545

4646
def test_cython_vs_numba_series(
@@ -55,7 +55,7 @@ def test_cython_vs_numba_series(
5555
)
5656
expected = getattr(gb, func)(**kwargs)
5757
# check_dtype can be removed if GH 44952 is addressed
58-
check_dtype = func != "sum"
58+
check_dtype = func not in ("sum", "min", "max")
5959
tm.assert_series_equal(result, expected, check_dtype=check_dtype)
6060

6161
def test_as_index_false_unsupported(self, numba_supported_reductions):

0 commit comments

Comments
 (0)