Skip to content

Commit 6bf6cd2

Browse files
5hirishjreback
authored andcommitted
BUG: GroupBy return EA dtype (#23318)
1 parent bd98841 commit 6bf6cd2

File tree

5 files changed

+45
-17
lines changed

5 files changed

+45
-17
lines changed

doc/source/whatsnew/v0.24.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
854854
- :func:`ExtensionArray.isna` is allowed to return an ``ExtensionArray`` (:issue:`22325`).
855855
- Support for reduction operations such as ``sum``, ``mean`` via opt-in base class method override (:issue:`22762`)
856856
- :meth:`Series.unstack` no longer converts extension arrays to object-dtype ndarrays. The output ``DataFrame`` will now have the same dtype as the input. This changes behavior for Categorical and Sparse data (:issue:`23077`).
857+
- Bug in :meth:`DataFrame.groupby` where aggregating on an ``ExtensionArray`` column did not return the actual ``ExtensionArray`` dtype (:issue:`23227`).
857858

858859
.. _whatsnew_0240.api.incompatibilities:
859860

@@ -1089,6 +1090,7 @@ Categorical
10891090
- Bug when indexing with a boolean-valued ``Categorical``. Now a boolean-valued ``Categorical`` is treated as a boolean mask (:issue:`22665`)
10901091
- Constructing a :class:`CategoricalIndex` with empty values and boolean categories was raising a ``ValueError`` after a change to dtype coercion (:issue:`22702`).
10911092
- Bug in :meth:`Categorical.take` with a user-provided ``fill_value`` not encoding the ``fill_value``, which could result in a ``ValueError``, incorrect results, or a segmentation fault (:issue:`23296`).
1093+
- Bug in :meth:`DataFrame.resample` where aggregating on categorical data lost the categorical dtype (:issue:`23227`).
10921094

10931095
Datetimelike
10941096
^^^^^^^^^^^^

pandas/core/groupby/groupby.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ class providing the base-class of operations.
2424
from pandas.util._validators import validate_kwargs
2525

2626
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
27-
from pandas.core.dtypes.common import ensure_float, is_numeric_dtype, is_scalar
27+
from pandas.core.dtypes.common import (
28+
ensure_float, is_extension_array_dtype, is_numeric_dtype, is_scalar)
2829
from pandas.core.dtypes.missing import isna, notna
2930

3031
import pandas.core.algorithms as algorithms
@@ -754,7 +755,18 @@ def _try_cast(self, result, obj, numeric_only=False):
754755
dtype = obj.dtype
755756

756757
if not is_scalar(result):
757-
if numeric_only and is_numeric_dtype(dtype) or not numeric_only:
758+
if is_extension_array_dtype(dtype):
759+
# The function can return something of any type, so check
760+
# if the type is compatible with the calling EA.
761+
try:
762+
result = obj.values._from_sequence(result)
763+
except Exception:
764+
# https://github.com/pandas-dev/pandas/issues/22850
765+
# pandas has no control over what 3rd-party ExtensionArrays
766+
# do in _from_sequence. We still want ops to work
767+
# though, so we catch any regular Exception.
768+
pass
769+
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
758770
result = maybe_downcast_to_dtype(result, dtype)
759771

760772
return result

pandas/tests/arrays/test_integer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,10 @@ def test_preserve_dtypes(op):
650650

651651
# groupby
652652
result = getattr(df.groupby("A"), op)()
653+
653654
expected = pd.DataFrame({
654655
"B": np.array([1.0, 3.0]),
655-
"C": np.array([1, 3], dtype="int64")
656+
"C": integer_array([1, 3], dtype="Int64")
656657
}, index=pd.Index(['a', 'b'], name='A'))
657658
tm.assert_frame_equal(result, expected)
658659

@@ -673,9 +674,10 @@ def test_reduce_to_float(op):
673674

674675
# groupby
675676
result = getattr(df.groupby("A"), op)()
677+
676678
expected = pd.DataFrame({
677679
"B": np.array([1.0, 3.0]),
678-
"C": np.array([1, 3], dtype="float64")
680+
"C": integer_array([1, 3], dtype="Int64")
679681
}, index=pd.Index(['a', 'b'], name='A'))
680682
tm.assert_frame_equal(result, expected)
681683

pandas/tests/sparse/test_groupby.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,39 @@ def test_first_last_nth(self):
2424
sparse_grouped = self.sparse.groupby('A')
2525
dense_grouped = self.dense.groupby('A')
2626

27+
sparse_grouped_first = sparse_grouped.first()
28+
sparse_grouped_last = sparse_grouped.last()
29+
sparse_grouped_nth = sparse_grouped.nth(1)
30+
31+
dense_grouped_first = dense_grouped.first().to_sparse()
32+
dense_grouped_last = dense_grouped.last().to_sparse()
33+
dense_grouped_nth = dense_grouped.nth(1).to_sparse()
34+
2735
# TODO: shouldn't these all be sparse or not?
28-
tm.assert_frame_equal(sparse_grouped.first(),
29-
dense_grouped.first())
30-
tm.assert_frame_equal(sparse_grouped.last(),
31-
dense_grouped.last())
32-
tm.assert_frame_equal(sparse_grouped.nth(1),
33-
dense_grouped.nth(1).to_sparse())
36+
tm.assert_frame_equal(sparse_grouped_first,
37+
dense_grouped_first)
38+
tm.assert_frame_equal(sparse_grouped_last,
39+
dense_grouped_last)
40+
tm.assert_frame_equal(sparse_grouped_nth,
41+
dense_grouped_nth)
3442

3543
def test_aggfuncs(self):
3644
sparse_grouped = self.sparse.groupby('A')
3745
dense_grouped = self.dense.groupby('A')
3846

39-
tm.assert_frame_equal(sparse_grouped.mean(),
40-
dense_grouped.mean())
47+
result = sparse_grouped.mean().to_sparse()
48+
expected = dense_grouped.mean().to_sparse()
49+
50+
tm.assert_frame_equal(result, expected)
4151

4252
# ToDo: sparse sum includes str column
4353
# tm.assert_frame_equal(sparse_grouped.sum(),
4454
# dense_grouped.sum())
4555

46-
tm.assert_frame_equal(sparse_grouped.count(),
47-
dense_grouped.count())
56+
result = sparse_grouped.count().to_sparse()
57+
expected = dense_grouped.count().to_sparse()
58+
59+
tm.assert_frame_equal(result, expected)
4860

4961

5062
@pytest.mark.parametrize("fill_value", [0, np.nan])
@@ -54,6 +66,5 @@ def test_groupby_includes_fill_value(fill_value):
5466
'b': [fill_value, 1, fill_value, fill_value]})
5567
sdf = df.to_sparse(fill_value=fill_value)
5668
result = sdf.groupby('a').sum()
57-
expected = df.groupby('a').sum()
58-
tm.assert_frame_equal(result, expected,
59-
check_index_type=False)
69+
expected = df.groupby('a').sum().to_sparse(fill_value=fill_value)
70+
tm.assert_frame_equal(result, expected, check_index_type=False)

pandas/tests/test_resample.py

+1
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,7 @@ def test_resample_categorical_data_with_timedeltaindex(self):
15761576
'Group': ['A', 'A']},
15771577
index=pd.to_timedelta([0, 10], unit='s'))
15781578
expected = expected.reindex(['Group_obj', 'Group'], axis=1)
1579+
expected['Group'] = expected['Group_obj'].astype('category')
15791580
tm.assert_frame_equal(result, expected)
15801581

15811582
def test_resample_daily_anchored(self):

0 commit comments

Comments
 (0)