Skip to content

BUG: groupby().agg fails on categorical column #31470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7e461a1
remove \n from docstring
charlesdong1991 Dec 3, 2018
1314059
fix conflicts
charlesdong1991 Jan 19, 2019
8bcb313
Merge remote-tracking branch 'upstream/master'
charlesdong1991 Jul 30, 2019
24c3ede
Merge remote-tracking branch 'upstream/master'
charlesdong1991 Jan 14, 2020
dea38f2
fix issue 17038
charlesdong1991 Jan 14, 2020
cd9e7ac
revert change
charlesdong1991 Jan 14, 2020
e5e912b
revert change
charlesdong1991 Jan 14, 2020
97f266f
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Jan 30, 2020
93ebadb
try fix
charlesdong1991 Jan 30, 2020
3520b95
upload test
charlesdong1991 Jan 30, 2020
32cc744
linting
charlesdong1991 Jan 30, 2020
9f936cc
broader concept
charlesdong1991 Jan 30, 2020
946c49f
fix up
charlesdong1991 Jan 30, 2020
73b01c6
imports
charlesdong1991 Jan 30, 2020
2fdb3f5
keep experimenting
charlesdong1991 Jan 30, 2020
9e52c70
fixtup
charlesdong1991 Jan 30, 2020
a366b02
add comment
charlesdong1991 Jan 30, 2020
bdfcfab
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Jan 31, 2020
36184f6
experiment
charlesdong1991 Feb 1, 2020
9d4e021
update
charlesdong1991 Feb 1, 2020
c588204
change base
charlesdong1991 Feb 1, 2020
a11279d
experiment
charlesdong1991 Feb 1, 2020
bb3ff98
experiment
charlesdong1991 Feb 1, 2020
5d0bcfd
experiment
charlesdong1991 Feb 1, 2020
cc516c8
experiemnt
charlesdong1991 Feb 1, 2020
3c5c3aa
experiment
charlesdong1991 Feb 3, 2020
a63e65d
fixup
charlesdong1991 Feb 3, 2020
4ba67e8
experiment
charlesdong1991 Feb 3, 2020
849f96f
experiment
charlesdong1991 Feb 3, 2020
50a7242
experiment
charlesdong1991 Feb 3, 2020
6635d31
experiment
charlesdong1991 Feb 3, 2020
b55b6b4
fixup and linting
charlesdong1991 Feb 3, 2020
5dd9b38
Merge remote-tracking branch 'upstream/master' into issue_31450
charlesdong1991 Feb 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def _gotitem(self, key, ndim, subset=None):

cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

cython_cast_cat_type_list = frozenset(["first", "last"])
cython_cast_keep_type_list = cython_cast_cat_type_list | frozenset(
["sum", "min", "max"]
)

# List of aggregation/reduction functions.
# These map each group to a single numeric value
reduction_kernels = frozenset(
Expand Down
23 changes: 15 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class providing the base-class of operations.
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, ops
from pandas.core.groupby.base import cython_cast_keep_type_list
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
Expand Down Expand Up @@ -792,7 +793,7 @@ def _cumcount_array(self, ascending: bool = True):
rev[sorter] = np.arange(count, dtype=np.intp)
return out[rev].astype(np.int64, copy=False)

def _try_cast(self, result, obj, numeric_only: bool = False):
def _try_cast(self, result, obj, numeric_only: bool = False, how=None):
"""
Try to cast the result to our obj original type,
we may have roundtripped through object in the mean-time.
Expand All @@ -807,13 +808,19 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
dtype = obj.dtype

if not is_scalar(result):
if is_extension_array_dtype(dtype) and dtype.kind != "M":
# The function can return something of any type, so check
# if the type is compatible with the calling EA.
# datetime64tz is handled correctly in agg_series,
# so is excluded here.

if len(result) and isinstance(result[0], dtype.type):
# The function can return something of any type, so check
# if the type is compatible with the calling EA.
# datetime64tz is handled correctly in agg_series,
# so is excluded here.
if is_extension_array_dtype(dtype) and dtype.kind != "M":
# if how is in cython_cast_keep_type_list, which means it
# should be cast back to return the same type as obj
if (
len(result)
and isinstance(result[0], dtype.type)
or how in cython_cast_keep_type_list
):
cls = dtype.construct_array_type()
result = try_cast_to_ea(cls, result, dtype=dtype)

Expand Down Expand Up @@ -900,7 +907,7 @@ def _cython_agg_general(
else:
assert result.ndim == 1
key = base.OutputKey(label=name, position=idx)
output[key] = self._try_cast(result, obj)
output[key] = self._try_cast(result, obj, how=how)
idx += 1

if len(output) == 0:
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, grouper
from pandas.core.groupby.base import cython_cast_cat_type_list
from pandas.core.indexes.api import Index, MultiIndex, ensure_index
from pandas.core.series import Series
from pandas.core.sorting import (
Expand Down Expand Up @@ -451,7 +452,12 @@ def _cython_operation(

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values) or is_sparse(values):
# those four cython agg that should work with categoricals
if (
is_categorical_dtype(values)
and how not in cython_cast_cat_type_list
or is_sparse(values)
):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values):
if how in ["add", "prod", "cumsum", "cumprod"]:
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,3 +1376,14 @@ def test_groupby_agg_non_numeric():

result = df.groupby([1, 2, 1]).nunique()
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("func", ["first", "last"])
def test_groupby_agg_categorical_first_last(func):
# GH 31450
df = pd.DataFrame({"col_num": [1, 1, 2, 3]})
df["col_cat"] = df["col_num"].astype("category")

grouped = df.groupby("col_num").agg({"col_cat": func})
expected = df.groupby("col_num").agg(func)
tm.assert_frame_equal(grouped, expected)