Skip to content

Commit 00a6aea

Browse files
lithomas1topper-123
authored and committed
ENH: Groupby agg support multiple funcs numba (pandas-dev#53486)
* ENH: Groupby agg support multiple funcs numba
* address code review
* remove old TODO
1 parent fe37123 commit 00a6aea

File tree

4 files changed

+155
-24
lines changed

4 files changed

+155
-24
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Other enhancements
9999
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
100100
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
101101
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
102+
- :meth:`SeriesGroupBy.agg` and :meth:`DataFrameGroupBy.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
102103
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
103104
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
104105
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)

pandas/core/apply.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,18 @@ def agg_dict_like(self) -> DataFrame | Series:
403403
and selected_obj.columns.nunique() < len(selected_obj.columns)
404404
)
405405

406+
# Numba Groupby engine/engine-kwargs passthrough
407+
kwargs = {}
408+
if is_groupby:
409+
engine = self.kwargs.get("engine", None)
410+
engine_kwargs = self.kwargs.get("engine_kwargs", None)
411+
kwargs = {"engine": engine, "engine_kwargs": engine_kwargs}
412+
406413
with context_manager:
407414
if selected_obj.ndim == 1:
408415
# key only used for output
409416
colg = obj._gotitem(selection, ndim=1)
410-
result_data = [colg.agg(how) for _, how in func.items()]
417+
result_data = [colg.agg(how, **kwargs) for _, how in func.items()]
411418
result_index = list(func.keys())
412419
elif is_non_unique_col:
413420
# key used for column selection and output
@@ -422,7 +429,7 @@ def agg_dict_like(self) -> DataFrame | Series:
422429
label_to_indices[label].append(index)
423430

424431
key_data = [
425-
selected_obj._ixs(indice, axis=1).agg(how)
432+
selected_obj._ixs(indice, axis=1).agg(how, **kwargs)
426433
for label, indices in label_to_indices.items()
427434
for indice in indices
428435
]
@@ -432,7 +439,8 @@ def agg_dict_like(self) -> DataFrame | Series:
432439
else:
433440
# key used for column selection and output
434441
result_data = [
435-
obj._gotitem(key, ndim=1).agg(how) for key, how in func.items()
442+
obj._gotitem(key, ndim=1).agg(how, **kwargs)
443+
for key, how in func.items()
436444
]
437445
result_index = list(func.keys())
438446

pandas/core/groupby/generic.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -223,24 +223,26 @@ def apply(self, func, *args, **kwargs) -> Series:
223223

224224
@doc(_agg_template_series, examples=_agg_examples_doc, klass="Series")
225225
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
226-
if maybe_use_numba(engine):
227-
return self._aggregate_with_numba(
228-
func, *args, engine_kwargs=engine_kwargs, **kwargs
229-
)
230-
231226
relabeling = func is None
232227
columns = None
233228
if relabeling:
234229
columns, func = validate_func_kwargs(kwargs)
235230
kwargs = {}
236231

237232
if isinstance(func, str):
233+
if maybe_use_numba(engine):
234+
# Not all agg functions support numba, only propagate numba kwargs
235+
# if user asks for numba
236+
kwargs["engine"] = engine
237+
kwargs["engine_kwargs"] = engine_kwargs
238238
return getattr(self, func)(*args, **kwargs)
239239

240240
elif isinstance(func, abc.Iterable):
241241
# Catch instances of lists / tuples
242242
# but not the class list / tuple itself.
243243
func = maybe_mangle_lambdas(func)
244+
kwargs["engine"] = engine
245+
kwargs["engine_kwargs"] = engine_kwargs
244246
ret = self._aggregate_multiple_funcs(func, *args, **kwargs)
245247
if relabeling:
246248
# columns is not narrowed by mypy from relabeling flag
@@ -255,6 +257,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
255257
if cyfunc and not args and not kwargs:
256258
return getattr(self, cyfunc)()
257259

260+
if maybe_use_numba(engine):
261+
return self._aggregate_with_numba(
262+
func, *args, engine_kwargs=engine_kwargs, **kwargs
263+
)
264+
258265
if self.ngroups == 0:
259266
# e.g. test_evaluate_with_empty_groups without any groups to
260267
# iterate over, we have no output on which to do dtype
@@ -1387,14 +1394,15 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
13871394

13881395
@doc(_agg_template_frame, examples=_agg_examples_doc, klass="DataFrame")
13891396
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
1390-
if maybe_use_numba(engine):
1391-
return self._aggregate_with_numba(
1392-
func, *args, engine_kwargs=engine_kwargs, **kwargs
1393-
)
1394-
13951397
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
13961398
func = maybe_mangle_lambdas(func)
13971399

1400+
if maybe_use_numba(engine):
1401+
# Not all agg functions support numba, only propagate numba kwargs
1402+
# if user asks for numba
1403+
kwargs["engine"] = engine
1404+
kwargs["engine_kwargs"] = engine_kwargs
1405+
13981406
op = GroupByApply(self, func, args=args, kwargs=kwargs)
13991407
result = op.agg()
14001408
if not is_dict_like(func) and result is not None:
@@ -1416,6 +1424,17 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
14161424
result.columns = columns # type: ignore[assignment]
14171425

14181426
if result is None:
1427+
# Remove the kwargs we inserted
1428+
# (already stored in engine, engine_kwargs arguments)
1429+
if "engine" in kwargs:
1430+
del kwargs["engine"]
1431+
del kwargs["engine_kwargs"]
1432+
# at this point func is not a str, list-like, dict-like,
1433+
# or a known callable (e.g. sum)
1434+
if maybe_use_numba(engine):
1435+
return self._aggregate_with_numba(
1436+
func, *args, engine_kwargs=engine_kwargs, **kwargs
1437+
)
14191438
# grouper specific aggregations
14201439
if self.grouper.nkeys > 1:
14211440
# test_groupby_as_index_series_scalar gets here with 'not self.as_index'

pandas/tests/groupby/aggregate/test_numba.py

+114-11
Original file line numberDiff line numberDiff line change
@@ -135,24 +135,127 @@ def func_1(values, index):
135135

136136
@td.skip_if_no("numba")
137137
@pytest.mark.parametrize(
138-
"agg_func",
138+
"agg_kwargs",
139139
[
140-
["min", "max"],
141-
"min",
142-
{"B": ["min", "max"], "C": "sum"},
143-
NamedAgg(column="B", aggfunc="min"),
140+
{"func": ["min", "max"]},
141+
{"func": "min"},
142+
{"func": {1: ["min", "max"], 2: "sum"}},
143+
{"bmin": NamedAgg(column=1, aggfunc="min")},
144144
],
145145
)
146-
def test_multifunc_notimplimented(agg_func):
146+
def test_multifunc_numba_vs_cython_frame(agg_kwargs):
147147
data = DataFrame(
148-
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
148+
{
149+
0: ["a", "a", "b", "b", "a"],
150+
1: [1.0, 2.0, 3.0, 4.0, 5.0],
151+
2: [1, 2, 3, 4, 5],
152+
},
153+
columns=[0, 1, 2],
154+
)
155+
grouped = data.groupby(0)
156+
result = grouped.agg(**agg_kwargs, engine="numba")
157+
expected = grouped.agg(**agg_kwargs, engine="cython")
158+
# check_dtype can be removed if GH 44952 is addressed
159+
tm.assert_frame_equal(result, expected, check_dtype=False)
160+
161+
162+
@td.skip_if_no("numba")
163+
@pytest.mark.parametrize(
164+
"agg_kwargs,expected_func",
165+
[
166+
({"func": lambda values, index: values.sum()}, "sum"),
167+
# FIXME
168+
pytest.param(
169+
{
170+
"func": [
171+
lambda values, index: values.sum(),
172+
lambda values, index: values.min(),
173+
]
174+
},
175+
["sum", "min"],
176+
marks=pytest.mark.xfail(
177+
reason="This doesn't work yet! Fails in nopython pipeline!"
178+
),
179+
),
180+
],
181+
)
182+
def test_multifunc_numba_udf_frame(agg_kwargs, expected_func):
183+
data = DataFrame(
184+
{
185+
0: ["a", "a", "b", "b", "a"],
186+
1: [1.0, 2.0, 3.0, 4.0, 5.0],
187+
2: [1, 2, 3, 4, 5],
188+
},
189+
columns=[0, 1, 2],
149190
)
150191
grouped = data.groupby(0)
151-
with pytest.raises(NotImplementedError, match="Numba engine can"):
152-
grouped.agg(agg_func, engine="numba")
192+
result = grouped.agg(**agg_kwargs, engine="numba")
193+
expected = grouped.agg(expected_func, engine="cython")
194+
# check_dtype can be removed if GH 44952 is addressed
195+
tm.assert_frame_equal(result, expected, check_dtype=False)
153196

154-
with pytest.raises(NotImplementedError, match="Numba engine can"):
155-
grouped[1].agg(agg_func, engine="numba")
197+
198+
@td.skip_if_no("numba")
199+
@pytest.mark.parametrize(
200+
"agg_kwargs",
201+
[{"func": ["min", "max"]}, {"func": "min"}, {"min_val": "min", "max_val": "max"}],
202+
)
203+
def test_multifunc_numba_vs_cython_series(agg_kwargs):
204+
labels = ["a", "a", "b", "b", "a"]
205+
data = Series([1.0, 2.0, 3.0, 4.0, 5.0])
206+
grouped = data.groupby(labels)
207+
agg_kwargs["engine"] = "numba"
208+
result = grouped.agg(**agg_kwargs)
209+
agg_kwargs["engine"] = "cython"
210+
expected = grouped.agg(**agg_kwargs)
211+
if isinstance(expected, DataFrame):
212+
tm.assert_frame_equal(result, expected)
213+
else:
214+
tm.assert_series_equal(result, expected)
215+
216+
217+
@td.skip_if_no("numba")
218+
@pytest.mark.single_cpu
219+
@pytest.mark.parametrize(
220+
"data,agg_kwargs",
221+
[
222+
(Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": ["min", "max"]}),
223+
(Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": "min"}),
224+
(
225+
DataFrame(
226+
{1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
227+
),
228+
{"func": ["min", "max"]},
229+
),
230+
(
231+
DataFrame(
232+
{1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
233+
),
234+
{"func": "min"},
235+
),
236+
(
237+
DataFrame(
238+
{1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
239+
),
240+
{"func": {1: ["min", "max"], 2: "sum"}},
241+
),
242+
(
243+
DataFrame(
244+
{1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
245+
),
246+
{"min_col": NamedAgg(column=1, aggfunc="min")},
247+
),
248+
],
249+
)
250+
def test_multifunc_numba_kwarg_propagation(data, agg_kwargs):
251+
labels = ["a", "a", "b", "b", "a"]
252+
grouped = data.groupby(labels)
253+
result = grouped.agg(**agg_kwargs, engine="numba", engine_kwargs={"parallel": True})
254+
expected = grouped.agg(**agg_kwargs, engine="numba")
255+
if isinstance(expected, DataFrame):
256+
tm.assert_frame_equal(result, expected)
257+
else:
258+
tm.assert_series_equal(result, expected)
156259

157260

158261
@td.skip_if_no("numba")

0 commit comments

Comments
 (0)