Skip to content

ENH: Groupby agg support multiple funcs numba #53486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ Other enhancements
- Performance improvement in :func:`read_csv` (:issue:`52632`) with ``engine="c"``
- :meth:`Categorical.from_codes` has gotten a ``validate`` parameter (:issue:`50975`)
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
- :meth:`SeriesGroupBy.agg` and :meth:`DataFrameGroupBy.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
Expand Down
14 changes: 11 additions & 3 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,18 @@ def agg_dict_like(self) -> DataFrame | Series:
and selected_obj.columns.nunique() < len(selected_obj.columns)
)

# Numba Groupby engine/engine-kwargs passthrough
kwargs = {}
if is_groupby:
engine = self.kwargs.get("engine", None)
engine_kwargs = self.kwargs.get("engine_kwargs", None)
kwargs = {"engine": engine, "engine_kwargs": engine_kwargs}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the method signatures by adding engine and engine_kwargs params instead of mangling kwargs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

obj here can be a Series/DF/Resampler too, since the code is shared there, so it wouldn't be possible to add engine/engine_kwargs there (unless we supported numba there too), since agg there is part of the public API.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

obj here can be a Series/DF/Resampler too, since the code is shared there, so it wouldn't be possible to add engine/engine_kwargs there (unless we supported numba there too), since agg there is part of the public API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more we evolve groupby, the more I realize trying to share this code between Series/DataFrame and groupby was a mistake of mine. I've been thinking we should refactor and separate off groupby entirely (but it can still live in core.apply). I'm okay with having a bit of a kludge here and will separate/cleanup after.


with context_manager:
if selected_obj.ndim == 1:
# key only used for output
colg = obj._gotitem(selection, ndim=1)
result_data = [colg.agg(how) for _, how in func.items()]
result_data = [colg.agg(how, **kwargs) for _, how in func.items()]
result_index = list(func.keys())
elif is_non_unique_col:
# key used for column selection and output
Expand All @@ -422,7 +429,7 @@ def agg_dict_like(self) -> DataFrame | Series:
label_to_indices[label].append(index)

key_data = [
selected_obj._ixs(indice, axis=1).agg(how)
selected_obj._ixs(indice, axis=1).agg(how, **kwargs)
for label, indices in label_to_indices.items()
for indice in indices
]
Expand All @@ -432,7 +439,8 @@ def agg_dict_like(self) -> DataFrame | Series:
else:
# key used for column selection and output
result_data = [
obj._gotitem(key, ndim=1).agg(how) for key, how in func.items()
obj._gotitem(key, ndim=1).agg(how, **kwargs)
for key, how in func.items()
]
result_index = list(func.keys())

Expand Down
39 changes: 29 additions & 10 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,24 +223,26 @@ def apply(self, func, *args, **kwargs) -> Series:

@doc(_agg_template_series, examples=_agg_examples_doc, klass="Series")
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
if maybe_use_numba(engine):
return self._aggregate_with_numba(
func, *args, engine_kwargs=engine_kwargs, **kwargs
)

relabeling = func is None
columns = None
if relabeling:
columns, func = validate_func_kwargs(kwargs)
kwargs = {}

if isinstance(func, str):
if maybe_use_numba(engine):
# Not all agg functions support numba, only propagate numba kwargs
# if user asks for numba
kwargs["engine"] = engine
kwargs["engine_kwargs"] = engine_kwargs
return getattr(self, func)(*args, **kwargs)

elif isinstance(func, abc.Iterable):
# Catch instances of lists / tuples
# but not the class list / tuple itself.
func = maybe_mangle_lambdas(func)
kwargs["engine"] = engine
kwargs["engine_kwargs"] = engine_kwargs
ret = self._aggregate_multiple_funcs(func, *args, **kwargs)
if relabeling:
# columns is not narrowed by mypy from relabeling flag
Expand All @@ -255,6 +257,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
if cyfunc and not args and not kwargs:
return getattr(self, cyfunc)()

if maybe_use_numba(engine):
return self._aggregate_with_numba(
func, *args, engine_kwargs=engine_kwargs, **kwargs
)

if self.ngroups == 0:
# e.g. test_evaluate_with_empty_groups without any groups to
# iterate over, we have no output on which to do dtype
Expand Down Expand Up @@ -1387,14 +1394,15 @@ class DataFrameGroupBy(GroupBy[DataFrame]):

@doc(_agg_template_frame, examples=_agg_examples_doc, klass="DataFrame")
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
if maybe_use_numba(engine):
return self._aggregate_with_numba(
func, *args, engine_kwargs=engine_kwargs, **kwargs
)

relabeling, func, columns, order = reconstruct_func(func, **kwargs)
func = maybe_mangle_lambdas(func)

if maybe_use_numba(engine):
# Not all agg functions support numba, only propagate numba kwargs
# if user asks for numba
kwargs["engine"] = engine
kwargs["engine_kwargs"] = engine_kwargs

op = GroupByApply(self, func, args=args, kwargs=kwargs)
result = op.agg()
if not is_dict_like(func) and result is not None:
Expand All @@ -1416,6 +1424,17 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
result.columns = columns # type: ignore[assignment]

if result is None:
# Remove the kwargs we inserted
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we need to go through result = op.agg() first instead of just using self._aggregate_with_numba in the if if maybe_use_numba(engine) branch above?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that checks for str/list/dict aggregations, and call back into the agg function itself.

It'll return None, if a callable is passed, though (what this block handles).

# (already stored in engine, engine_kwargs arguments)
if "engine" in kwargs:
del kwargs["engine"]
del kwargs["engine_kwargs"]
# at this point func is not a str, list-like, dict-like,
# or a known callable(e.g. sum)
if maybe_use_numba(engine):
return self._aggregate_with_numba(
func, *args, engine_kwargs=engine_kwargs, **kwargs
)
# grouper specific aggregations
if self.grouper.nkeys > 1:
# test_groupby_as_index_series_scalar gets here with 'not self.as_index'
Expand Down
126 changes: 115 additions & 11 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,129 @@ def func_1(values, index):


@td.skip_if_no("numba")
# TODO: Add test to check that UDF is still jitted by numba
@pytest.mark.parametrize(
    "agg_kwargs",
    [
        {"func": ["min", "max"]},
        {"func": "min"},
        {"func": {1: ["min", "max"], 2: "sum"}},
        {"bmin": NamedAgg(column=1, aggfunc="min")},
    ],
)
def test_multifunc_numba_vs_cython_frame(agg_kwargs):
    """Numba-engine DataFrameGroupBy.agg must match the cython engine for
    list-of-funcs, single-func, dict-of-funcs, and NamedAgg specs."""
    data = DataFrame(
        {
            0: ["a", "a", "b", "b", "a"],
            1: [1.0, 2.0, 3.0, 4.0, 5.0],
            2: [1, 2, 3, 4, 5],
        },
        columns=[0, 1, 2],
    )
    grouped = data.groupby(0)
    result = grouped.agg(**agg_kwargs, engine="numba")
    expected = grouped.agg(**agg_kwargs, engine="cython")
    # check_dtype can be removed if GH 44952 is addressed
    tm.assert_frame_equal(result, expected, check_dtype=False)


@td.skip_if_no("numba")
@pytest.mark.parametrize(
    "agg_kwargs,expected_func",
    [
        ({"func": lambda values, index: values.sum()}, "sum"),
        # FIXME
        pytest.param(
            {
                "func": [
                    lambda values, index: values.sum(),
                    lambda values, index: values.min(),
                ]
            },
            ["sum", "min"],
            marks=pytest.mark.xfail(
                reason="This doesn't work yet! Fails in nopython pipeline!"
            ),
        ),
    ],
)
def test_multifunc_numba_udf_frame(agg_kwargs, expected_func):
    """A numba UDF passed to DataFrameGroupBy.agg must match the equivalent
    built-in aggregation run under the cython engine."""
    data = DataFrame(
        {
            0: ["a", "a", "b", "b", "a"],
            1: [1.0, 2.0, 3.0, 4.0, 5.0],
            2: [1, 2, 3, 4, 5],
        },
        columns=[0, 1, 2],
    )
    grouped = data.groupby(0)
    result = grouped.agg(**agg_kwargs, engine="numba")
    expected = grouped.agg(expected_func, engine="cython")
    # check_dtype can be removed if GH 44952 is addressed
    tm.assert_frame_equal(result, expected, check_dtype=False)

@td.skip_if_no("numba")
@pytest.mark.parametrize(
    "agg_kwargs",
    [
        {"func": ["min", "max"]},
        {"func": "min"},
        {"min_val": "min", "max_val": "max"},
    ],
)
def test_multifunc_numba_vs_cython_series(agg_kwargs):
    # The numba engine must give the same result as the cython engine for
    # list-of-funcs, single-func, and relabeling specs on a SeriesGroupBy.
    ser = Series([1.0, 2.0, 3.0, 4.0, 5.0])
    gb = ser.groupby(["a", "a", "b", "b", "a"])
    result = gb.agg(**agg_kwargs, engine="numba")
    expected = gb.agg(**agg_kwargs, engine="cython")
    # A list/relabeling spec yields a DataFrame; a single func yields a Series.
    comparator = (
        tm.assert_frame_equal
        if isinstance(expected, DataFrame)
        else tm.assert_series_equal
    )
    comparator(result, expected)


@td.skip_if_no("numba")
@pytest.mark.single_cpu
@pytest.mark.parametrize(
    "data,agg_kwargs",
    [
        (Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": ["min", "max"]}),
        (Series([1.0, 2.0, 3.0, 4.0, 5.0]), {"func": "min"}),
        (
            DataFrame(
                {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
            ),
            {"func": ["min", "max"]},
        ),
        (
            DataFrame(
                {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
            ),
            {"func": "min"},
        ),
        (
            DataFrame(
                {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
            ),
            {"func": {1: ["min", "max"], 2: "sum"}},
        ),
        (
            DataFrame(
                {1: [1.0, 2.0, 3.0, 4.0, 5.0], 2: [1, 2, 3, 4, 5]}, columns=[1, 2]
            ),
            {"min_col": NamedAgg(column=1, aggfunc="min")},
        ),
    ],
)
def test_multifunc_numba_kwarg_propagation(data, agg_kwargs):
    # engine_kwargs passed to agg must be forwarded to each per-column numba
    # aggregation: results with and without {"parallel": True} must agree.
    gb = data.groupby(["a", "a", "b", "b", "a"])
    result = gb.agg(**agg_kwargs, engine="numba", engine_kwargs={"parallel": True})
    expected = gb.agg(**agg_kwargs, engine="numba")
    if isinstance(expected, Series):
        tm.assert_series_equal(result, expected)
    else:
        tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba")
Expand Down