Skip to content

Commit ec70c57

Browse files
committed
PERF: fix #32976 slow group by for categorical columns
Aggregate categorical codes with fast cython aggregation for select `how` operations.
1 parent 77a0f19 commit ec70c57

File tree

4 files changed

+67
-2
lines changed

4 files changed

+67
-2
lines changed

asv_bench/benchmarks/groupby.py

+24
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pandas import (
88
Categorical,
9+
CategoricalDtype,
910
DataFrame,
1011
MultiIndex,
1112
Series,
@@ -473,6 +474,7 @@ def time_sum(self):
473474

474475

475476
class Categories:
477+
# benchmark grouping by categoricals
476478
def setup(self):
477479
N = 10 ** 5
478480
arr = np.random.random(N)
@@ -510,6 +512,28 @@ def time_groupby_extra_cat_nosort(self):
510512
self.df_extra_cat.groupby("a", sort=False)["b"].count()
511513

512514

515+
class CategoricalFrame:
516+
# benchmark grouping with operations on categorical values (GH #32976)
517+
def setup(self):
518+
SIZE = 100000
519+
GROUPS = 10000 # The larger, the more extreme the timing differences
520+
CARDINALITY = 10
521+
CAT = CategoricalDtype(list(range(CARDINALITY)))
522+
df_int = DataFrame(
523+
{
524+
"group": [np.random.randint(0, GROUPS) for i in range(SIZE)],
525+
"cat": [np.random.choice(CAT.categories) for i in range(SIZE)],
526+
}
527+
)
528+
self.df_cat_values = df_int.astype({"cat": CAT})
529+
530+
def time_groupby(self):
531+
self.df_cat_values.groupby("group").last()
532+
533+
def time_groupby_ordered(self):
534+
self.df_cat_values.groupby("group", sort=True).last()
535+
536+
513537
class Datelike:
514538
# GH 14338
515539
params = ["period_range", "date_range", "date_range_tz"]

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ Performance improvements
427427
:meth:`DataFrame.sparse.from_spmatrix` constructor (:issue:`32821`,
428428
:issue:`32825`, :issue:`32826`, :issue:`32856`, :issue:`32858`).
429429
- Performance improvement in reductions (sum, prod, min, max) for nullable (integer and boolean) dtypes (:issue:`30982`, :issue:`33261`, :issue:`33442`).
430+
- Performance improvement in :meth:`DataFrame.groupby` when aggregating categorical data (:issue:`32976`)
430431

431432

432433
.. ---------------------------------------------------------------------------

pandas/core/groupby/ops.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pandas.core.dtypes.missing import _maybe_fill, isna
4040

4141
import pandas.core.algorithms as algorithms
42+
from pandas.core.arrays.categorical import Categorical
4243
from pandas.core.base import SelectionMixin
4344
import pandas.core.common as com
4445
from pandas.core.frame import DataFrame
@@ -451,7 +452,7 @@ def _cython_operation(
451452

452453
# categoricals are only 1d, so we
453454
# are not setup for dim transforming
454-
if is_categorical_dtype(values) or is_sparse(values):
455+
if is_sparse(values):
455456
raise NotImplementedError(f"{values.dtype} dtype not supported")
456457
elif is_datetime64_any_dtype(values):
457458
if how in ["add", "prod", "cumsum", "cumprod"]:
@@ -472,6 +473,29 @@ def _cython_operation(
472473

473474
is_datetimelike = needs_i8_conversion(values.dtype)
474475
is_numeric = is_numeric_dtype(values.dtype)
476+
is_categorical = is_categorical_dtype(values)
477+
cat_method_blacklist = (
478+
"add",
479+
"median",
480+
"prod",
481+
"sem",
482+
"cumsum",
483+
"sum",
484+
"cummin",
485+
"mean",
486+
"max",
487+
"skew",
488+
"cumprod",
489+
"cummax",
490+
"rank",
491+
"pct_change",
492+
"min",
493+
"var",
494+
"mad",
495+
"describe",
496+
"std",
497+
"quantile",
498+
)
475499

476500
if is_datetimelike:
477501
values = values.view("int64")
@@ -487,6 +511,17 @@ def _cython_operation(
487511
values = ensure_int_or_float(values)
488512
elif is_numeric and not is_complex_dtype(values):
489513
values = ensure_float64(values)
514+
elif is_categorical:
515+
if how in cat_method_blacklist:
516+
raise NotImplementedError(
517+
f"{values.dtype} dtype not supported for `how` argument {how}"
518+
)
519+
values, categories, ordered = (
520+
values.codes.astype(np.int64),
521+
values.categories,
522+
values.ordered,
523+
)
524+
is_numeric = True
490525
else:
491526
values = values.astype(object)
492527

@@ -574,6 +609,11 @@ def _cython_operation(
574609
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
575610
elif is_datetimelike and kind == "aggregate":
576611
result = result.astype(orig_values.dtype)
612+
elif is_categorical:
613+
# re-create categories
614+
result = Categorical.from_codes(
615+
result, categories=categories, ordered=ordered,
616+
)
577617

578618
return result, names
579619

pandas/tests/groupby/aggregate/test_aggregate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def test_agg_cython_category_not_implemented_fallback():
466466
result = df.groupby("col_num").col_cat.first()
467467
expected = pd.Series(
468468
[1, 2, 3], index=pd.Index([1, 2, 3], name="col_num"), name="col_cat"
469-
)
469+
).astype("category")
470470
tm.assert_series_equal(result, expected)
471471

472472
result = df.groupby("col_num").agg({"col_cat": "first"})

0 commit comments

Comments
 (0)