Skip to content

Commit 06ef193

Browse files
BUG: regression when applying groupby aggregation on categorical columns (#31359)
1 parent a05e6c9 commit 06ef193

File tree

7 files changed

+122
-5
lines changed

7 files changed

+122
-5
lines changed

doc/source/whatsnew/v1.0.0.rst

+48
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,54 @@ consistent with the behaviour of :class:`DataFrame` and :class:`Index`.
626626
DeprecationWarning: The default dtype for empty Series will be 'object' instead of 'float64' in a future version. Specify a dtype explicitly to silence this warning.
627627
Series([], dtype: float64)
628628
629+
Result dtype inference changes for resample operations
630+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
631+
632+
The rules for the result dtype in :meth:`DataFrame.resample` aggregations have changed for extension types (:issue:`31359`).
633+
Previously, pandas would attempt to convert the result back to the original dtype, falling back to the usual
634+
inference rules if that was not possible. Now, pandas will only return a result of the original dtype if the
635+
scalar values in the result are instances of the extension dtype's scalar type.
636+
637+
.. ipython:: python
638+
639+
df = pd.DataFrame({"A": ['a', 'b']}, dtype='category',
640+
index=pd.date_range('2000', periods=2))
641+
df
642+
643+
644+
*pandas 0.25.x*
645+
646+
.. code-block:: python
647+
648+
>>> df.resample("2D").agg(lambda x: 'a').A.dtype
649+
CategoricalDtype(categories=['a', 'b'], ordered=False)
650+
651+
*pandas 1.0.0*
652+
653+
.. ipython:: python
654+
655+
df.resample("2D").agg(lambda x: 'a').A.dtype
656+
657+
This fixes an inconsistency between ``resample`` and ``groupby``.
658+
This also fixes a potential bug, where the **values** of the result might change
659+
depending on how the results are cast back to the original dtype.
660+
661+
*pandas 0.25.x*
662+
663+
.. code-block:: python
664+
665+
>>> df.resample("2D").agg(lambda x: 'c')
666+
667+
A
668+
0 NaN
669+
670+
*pandas 1.0.0*
671+
672+
.. ipython:: python
673+
674+
df.resample("2D").agg(lambda x: 'c')
675+
676+
629677
.. _whatsnew_100.api_breaking.python:
630678

631679
Increased minimum version for Python

pandas/core/groupby/groupby.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,10 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
813813
# datetime64tz is handled correctly in agg_series,
814814
# so is excluded here.
815815

816-
# return the same type (Series) as our caller
817-
cls = dtype.construct_array_type()
818-
result = try_cast_to_ea(cls, result, dtype=dtype)
816+
if len(result) and isinstance(result[0], dtype.type):
817+
cls = dtype.construct_array_type()
818+
result = try_cast_to_ea(cls, result, dtype=dtype)
819+
819820
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
820821
result = maybe_downcast_to_dtype(result, dtype)
821822

pandas/core/groupby/ops.py

+11
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,17 @@ def _cython_operation(
543543
if mask.any():
544544
result = result.astype("float64")
545545
result[mask] = np.nan
546+
elif (
547+
how == "add"
548+
and is_integer_dtype(orig_values.dtype)
549+
and is_extension_array_dtype(orig_values.dtype)
550+
):
551+
# We need this to ensure that Series[Int64Dtype].resample().sum()
552+
# remains int64 dtype.
553+
# Two options for avoiding this special case
554+
# 1. mask-aware ops and avoid casting to float with NaN above
555+
# 2. specify the result dtype when calling this method
556+
result = result.astype("int64")
546557

547558
if kind == "aggregate" and self._filter_empty_groups and not counts.all():
548559
assert result.ndim != 2

pandas/tests/groupby/aggregate/test_aggregate.py

+21
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,27 @@ def test_aggregate_mixed_types():
663663
tm.assert_frame_equal(result, expected)
664664

665665

666+
@pytest.mark.xfail(reason="Not implemented.")
667+
def test_aggregate_udf_na_extension_type():
668+
# https://github.com/pandas-dev/pandas/pull/31359
669+
# This is currently failing to cast back to Int64Dtype.
670+
# The presence of the NA causes two problems
671+
# 1. NA is not an instance of Int64Dtype.type (numpy.int64)
672+
# 2. The presence of an NA forces object type, so the non-NA value is
673+
# a Python int rather than a NumPy int64. Python ints aren't
674+
# instances of numpy.int64.
675+
def aggfunc(x):
676+
if all(x > 2):
677+
return 1
678+
else:
679+
return pd.NA
680+
681+
df = pd.DataFrame({"A": pd.array([1, 2, 3])})
682+
result = df.groupby([1, 1, 2]).agg(aggfunc)
683+
expected = pd.DataFrame({"A": pd.array([1, pd.NA], dtype="Int64")}, index=[1, 2])
684+
tm.assert_frame_equal(result, expected)
685+
686+
666687
class TestLambdaMangling:
667688
def test_basic(self):
668689
df = pd.DataFrame({"A": [0, 0, 1, 1], "B": [1, 2, 3, 4]})

pandas/tests/groupby/test_categorical.py

+34
Original file line numberDiff line numberDiff line change
@@ -1342,3 +1342,37 @@ def test_series_groupby_categorical_aggregation_getitem():
13421342
result = groups["foo"].agg("mean")
13431343
expected = groups.agg("mean")["foo"]
13441344
tm.assert_series_equal(result, expected)
1345+
1346+
1347+
@pytest.mark.parametrize(
1348+
"func, expected_values",
1349+
[(pd.Series.nunique, [1, 1, 2]), (pd.Series.count, [1, 2, 2])],
1350+
)
1351+
def test_groupby_agg_categorical_columns(func, expected_values):
1352+
# 31256
1353+
df = pd.DataFrame(
1354+
{
1355+
"id": [0, 1, 2, 3, 4],
1356+
"groups": [0, 1, 1, 2, 2],
1357+
"value": pd.Categorical([0, 0, 0, 0, 1]),
1358+
}
1359+
).set_index("id")
1360+
result = df.groupby("groups").agg(func)
1361+
1362+
expected = pd.DataFrame(
1363+
{"value": expected_values}, index=pd.Index([0, 1, 2], name="groups"),
1364+
)
1365+
tm.assert_frame_equal(result, expected)
1366+
1367+
1368+
def test_groupby_agg_non_numeric():
1369+
df = pd.DataFrame(
1370+
{"A": pd.Categorical(["a", "a", "b"], categories=["a", "b", "c"])}
1371+
)
1372+
expected = pd.DataFrame({"A": [2, 1]}, index=[1, 2])
1373+
1374+
result = df.groupby([1, 2, 1]).agg(pd.Series.nunique)
1375+
tm.assert_frame_equal(result, expected)
1376+
1377+
result = df.groupby([1, 2, 1]).nunique()
1378+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_datetime_index.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def test_resample_integerarray():
122122

123123
result = ts.resample("3T").mean()
124124
expected = Series(
125-
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64"
125+
[1, 4, 7],
126+
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
127+
dtype="float64",
126128
)
127129
tm.assert_series_equal(result, expected)
128130

pandas/tests/resample/test_timedelta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_resample_categorical_data_with_timedeltaindex():
105105
index=pd.to_timedelta([0, 10], unit="s"),
106106
)
107107
expected = expected.reindex(["Group_obj", "Group"], axis=1)
108-
expected["Group"] = expected["Group_obj"].astype("category")
108+
expected["Group"] = expected["Group_obj"]
109109
tm.assert_frame_equal(result, expected)
110110

111111

0 commit comments

Comments
 (0)