Skip to content

Commit 0fd3dde

Browse files
committed
PERF: fix #32976 slow group by for categorical columns
Aggregate categorical codes with fast cython aggregation for select `how` operations. 8/1/20: rebase and move release note to 1.2 8/2/20: Update tests to expect categorical back 8/3/20: add PR as issue for whatsnew groupby api change
1 parent cda8284 commit 0fd3dde

File tree

4 files changed

+77
-3
lines changed

4 files changed

+77
-3
lines changed

asv_bench/benchmarks/groupby.py

+29
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pandas import (
88
Categorical,
9+
CategoricalDtype,
910
DataFrame,
1011
MultiIndex,
1112
Series,
@@ -473,6 +474,7 @@ def time_sum(self):
473474

474475

475476
class Categories:
477+
# benchmark grouping by categoricals
476478
def setup(self):
477479
N = 10 ** 5
478480
arr = np.random.random(N)
@@ -510,6 +512,33 @@ def time_groupby_extra_cat_nosort(self):
510512
self.df_extra_cat.groupby("a", sort=False)["b"].count()
511513

512514

515+
class CategoricalFrame:
516+
# benchmark grouping with operations on categorical values (GH #32976)
517+
param_names = ["groupby_type", "value_type", "agg_method"]
518+
params = [(int,), (int, str), ("last", "head", "count")]
519+
520+
def setup(self, groupby_type, value_type, agg_method):
521+
SIZE = 100000
522+
GROUPS = 1000
523+
CARDINALITY = 10
524+
CAT = CategoricalDtype([value_type(i) for i in range(CARDINALITY)])
525+
df = DataFrame(
526+
{
527+
"group": [
528+
groupby_type(np.random.randint(0, GROUPS)) for i in range(SIZE)
529+
],
530+
"cat": [np.random.choice(CAT.categories) for i in range(SIZE)],
531+
}
532+
)
533+
self.df_cat_values = df.astype({"cat": CAT})
534+
535+
def time_groupby(self, groupby_type, value_type, agg_method):
536+
getattr(self.df_cat_values.groupby("group"), agg_method)()
537+
538+
def time_groupby_ordered(self, groupby_type, value_type, agg_method):
539+
getattr(self.df_cat_values.groupby("group", sort=True), agg_method)()
540+
541+
513542
class Datelike:
514543
# GH 14338
515544
params = ["period_range", "date_range", "date_range_tz"]

doc/source/whatsnew/v1.2.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Deprecations
4040
Performance improvements
4141
~~~~~~~~~~~~~~~~~~~~~~~~
4242

43+
- Performance improvement in :meth:`DataFrame.groupby` when aggregating categorical data (:issue:`32976`)
4344
-
4445
-
4546

@@ -132,6 +133,7 @@ Plotting
132133
Groupby/resample/rolling
133134
^^^^^^^^^^^^^^^^^^^^^^^^
134135

136+
- :meth:`DataFrame.groupby` aggregations of categorical series will now return a :class:`Categorical` while preserving the codes and categories of the original series (:issue:`33739`)
135137
-
136138
-
137139

pandas/core/groupby/ops.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from pandas.core.dtypes.missing import _maybe_fill, isna
4141

4242
import pandas.core.algorithms as algorithms
43+
from pandas.core.arrays.categorical import Categorical
4344
from pandas.core.base import SelectionMixin
4445
import pandas.core.common as com
4546
from pandas.core.frame import DataFrame
@@ -356,6 +357,29 @@ def get_group_levels(self) -> List[Index]:
356357

357358
_name_functions = {"ohlc": ["open", "high", "low", "close"]}
358359

360+
_cat_method_blacklist = (
361+
"add",
362+
"median",
363+
"prod",
364+
"sem",
365+
"cumsum",
366+
"sum",
367+
"cummin",
368+
"mean",
369+
"max",
370+
"skew",
371+
"cumprod",
372+
"cummax",
373+
"rank",
374+
"pct_change",
375+
"min",
376+
"var",
377+
"mad",
378+
"describe",
379+
"std",
380+
"quantile",
381+
)
382+
359383
def _is_builtin_func(self, arg):
360384
"""
361385
if we define a builtin function for this argument, return it,
@@ -460,7 +484,7 @@ def _cython_operation(
460484

461485
# categoricals are only 1d, so we
462486
# are not setup for dim transforming
463-
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
487+
if is_sparse(values.dtype):
464488
raise NotImplementedError(f"{values.dtype} dtype not supported")
465489
elif is_datetime64_any_dtype(values.dtype):
466490
if how in ["add", "prod", "cumsum", "cumprod"]:
@@ -481,6 +505,7 @@ def _cython_operation(
481505

482506
is_datetimelike = needs_i8_conversion(values.dtype)
483507
is_numeric = is_numeric_dtype(values.dtype)
508+
is_categorical = is_categorical_dtype(values)
484509

485510
if is_datetimelike:
486511
values = values.view("int64")
@@ -496,6 +521,17 @@ def _cython_operation(
496521
values = ensure_int_or_float(values)
497522
elif is_numeric and not is_complex_dtype(values):
498523
values = ensure_float64(values)
524+
elif is_categorical:
525+
if how in self._cat_method_blacklist:
526+
raise NotImplementedError(
527+
f"{values.dtype} dtype not supported for `how` argument {how}"
528+
)
529+
values, categories, ordered = (
530+
values.codes.astype(np.int64),
531+
values.categories,
532+
values.ordered,
533+
)
534+
is_numeric = True
499535
else:
500536
values = values.astype(object)
501537

@@ -572,6 +608,11 @@ def _cython_operation(
572608
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
573609
elif is_datetimelike and kind == "aggregate":
574610
result = result.astype(orig_values.dtype)
611+
elif is_categorical:
612+
# re-create categories
613+
result = Categorical.from_codes(
614+
result, categories=categories, ordered=ordered,
615+
)
575616

576617
if is_extension_array_dtype(orig_values.dtype):
577618
result = maybe_cast_result(result=result, obj=orig_values, how=how)

pandas/tests/groupby/test_categorical.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,9 @@ def test_groupy_first_returned_categorical_instead_of_dataframe(func):
15051505
)
15061506
df_grouped = df.groupby("A")["B"]
15071507
result = getattr(df_grouped, func)()
1508-
expected = pd.Series(["b"], index=pd.Index([1997], name="A"), name="B")
1508+
expected = pd.Series(
1509+
["b"], index=pd.Index([1997], name="A"), name="B", dtype="category"
1510+
).cat.as_ordered()
15091511
tm.assert_series_equal(result, expected)
15101512

15111513

@@ -1574,7 +1576,7 @@ def test_agg_cython_category_not_implemented_fallback():
15741576
result = df.groupby("col_num").col_cat.first()
15751577
expected = pd.Series(
15761578
[1, 2, 3], index=pd.Index([1, 2, 3], name="col_num"), name="col_cat"
1577-
)
1579+
).astype("category")
15781580
tm.assert_series_equal(result, expected)
15791581

15801582
result = df.groupby("col_num").agg({"col_cat": "first"})

0 commit comments

Comments
 (0)