Skip to content

PERF: fix #32976 slow group by for categorical columns #33739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pandas import (
Categorical,
CategoricalDtype,
DataFrame,
MultiIndex,
Series,
Expand Down Expand Up @@ -473,6 +474,7 @@ def time_sum(self):


class Categories:
# benchmark grouping by categoricals
def setup(self):
N = 10 ** 5
arr = np.random.random(N)
Expand Down Expand Up @@ -510,6 +512,33 @@ def time_groupby_extra_cat_nosort(self):
self.df_extra_cat.groupby("a", sort=False)["b"].count()


class CategoricalFrame:
    # Benchmark groupby aggregations on categorical value columns (GH #32976).
    # Parametrized over the group-key type, the category value type, and the
    # aggregation method applied to the categorical column.
    param_names = ["groupby_type", "value_type", "agg_method"]
    params = [(int,), (int, str), ("last", "head", "count")]

    def setup(self, groupby_type, value_type, agg_method):
        SIZE = 100000
        GROUPS = 1000
        CARDINALITY = 10
        CAT = CategoricalDtype([value_type(i) for i in range(CARDINALITY)])
        # Draw all random values in one vectorized call each, instead of
        # calling np.random.randint / np.random.choice once per row in a
        # Python loop (the original setup did 200k scalar RNG calls).
        group_codes = np.random.randint(0, GROUPS, size=SIZE)
        df = DataFrame(
            {
                "group": [groupby_type(i) for i in group_codes],
                "cat": np.random.choice(CAT.categories, size=SIZE),
            }
        )
        self.df_cat_values = df.astype({"cat": CAT})

    def time_groupby(self, groupby_type, value_type, agg_method):
        # sort=False so this measures a distinct path from
        # time_groupby_ordered; groupby's default is sort=True, so without
        # this the two benchmarks timed exactly the same operation.
        getattr(self.df_cat_values.groupby("group", sort=False), agg_method)()

    def time_groupby_ordered(self, groupby_type, value_type, agg_method):
        getattr(self.df_cat_values.groupby("group", sort=True), agg_method)()


class Datelike:
# GH 14338
params = ["period_range", "date_range", "date_range_tz"]
Expand Down
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Deprecations
Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~

- Performance improvement in :meth:`DataFrame.groupby` when aggregating categorical data (:issue:`32976`)
-
-

Expand Down Expand Up @@ -132,6 +133,7 @@ Plotting
Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^

- :meth:`DataFrame.groupby` aggregations of categorical series will now return a :class:`Categorical` while preserving the codes and categories of the original series (:issue:`33739`)
-
-

Expand Down
43 changes: 42 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from pandas.core.dtypes.missing import _maybe_fill, isna

import pandas.core.algorithms as algorithms
from pandas.core.arrays.categorical import Categorical
from pandas.core.base import SelectionMixin
import pandas.core.common as com
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -356,6 +357,29 @@ def get_group_levels(self) -> List[Index]:

_name_functions = {"ohlc": ["open", "high", "low", "close"]}

_cat_method_blacklist = (
"add",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what methods don't work?

"median",
"prod",
"sem",
"cumsum",
"sum",
"cummin",
"mean",
"max",
"skew",
"cumprod",
"cummax",
"rank",
"pct_change",
"min",
"var",
"mad",
"describe",
"std",
"quantile",
)

def _is_builtin_func(self, arg):
"""
if we define a builtin function for this argument, return it,
Expand Down Expand Up @@ -460,7 +484,7 @@ def _cython_operation(

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
if is_sparse(values.dtype):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values.dtype):
if how in ["add", "prod", "cumsum", "cumprod"]:
Expand All @@ -481,6 +505,7 @@ def _cython_operation(

is_datetimelike = needs_i8_conversion(values.dtype)
is_numeric = is_numeric_dtype(values.dtype)
is_categorical = is_categorical_dtype(values)

if is_datetimelike:
values = values.view("int64")
Expand All @@ -496,6 +521,17 @@ def _cython_operation(
values = ensure_int_or_float(values)
elif is_numeric and not is_complex_dtype(values):
values = ensure_float64(values)
elif is_categorical:
if how in self._cat_method_blacklist:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't like doing this. Can you elaborate when we can actually process this? listing methods is a bad idea generally.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not exactly sure what you mean by "when we can actually process this", but I agree that listing methods isn't necessarily thorough and isn't robust. However, I've been unable to find a suitable alternative to blacklisting methods where we don't want to apply the aggregation on the category codes -- open to other ideas though. Please find more context in this thread.

raise NotImplementedError(
f"{values.dtype} dtype not supported for `how` argument {how}"
)
values, categories, ordered = (
values.codes.astype(np.int64),
values.categories,
values.ordered,
)
is_numeric = True
else:
values = values.astype(object)

Expand Down Expand Up @@ -572,6 +608,11 @@ def _cython_operation(
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
elif is_datetimelike and kind == "aggregate":
result = result.astype(orig_values.dtype)
elif is_categorical:
# re-create categories
result = Categorical.from_codes(
result, categories=categories, ordered=ordered,
)

if is_extension_array_dtype(orig_values.dtype):
result = maybe_cast_result(result=result, obj=orig_values, how=how)
Expand Down
6 changes: 4 additions & 2 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,7 +1505,9 @@ def test_groupy_first_returned_categorical_instead_of_dataframe(func):
)
df_grouped = df.groupby("A")["B"]
result = getattr(df_grouped, func)()
expected = pd.Series(["b"], index=pd.Index([1997], name="A"), name="B")
expected = pd.Series(
["b"], index=pd.Index([1997], name="A"), name="B", dtype="category"
).cat.as_ordered()
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -1574,7 +1576,7 @@ def test_agg_cython_category_not_implemented_fallback():
result = df.groupby("col_num").col_cat.first()
expected = pd.Series(
[1, 2, 3], index=pd.Index([1, 2, 3], name="col_num"), name="col_cat"
)
).astype("category")
tm.assert_series_equal(result, expected)

result = df.groupby("col_num").agg({"col_cat": "first"})
Expand Down