Skip to content

Commit bfa2b9e

Browse files
BUG: regression when applying groupby aggregation on categorical columns (#31359) (#31428)
Co-authored-by: Kaiqi Dong <[email protected]>
1 parent 05a0b63 commit bfa2b9e

File tree

7 files changed

+138
-5
lines changed

7 files changed

+138
-5
lines changed

doc/source/whatsnew/v1.0.0.rst

+48
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,54 @@ consistent with the behaviour of :class:`DataFrame` and :class:`Index`.
626626
DeprecationWarning: The default dtype for empty Series will be 'object' instead of 'float64' in a future version. Specify a dtype explicitly to silence this warning.
627627
Series([], dtype: float64)
628628
629+
Result dtype inference changes for resample operations
630+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
631+
632+
The rules for the result dtype in :meth:`DataFrame.resample` aggregations have changed for extension types (:issue:`31359`).
633+
Previously, pandas would attempt to convert the result back to the original dtype, falling back to the usual
634+
inference rules if that was not possible. Now, pandas will only return a result of the original dtype if the
635+
scalar values in the result are instances of the extension dtype's scalar type.
636+
637+
.. ipython:: python
638+
639+
df = pd.DataFrame({"A": ['a', 'b']}, dtype='category',
640+
index=pd.date_range('2000', periods=2))
641+
df
642+
643+
644+
*pandas 0.25.x*
645+
646+
.. code-block:: python
647+
648+
>>> df.resample("2D").agg(lambda x: 'a').A.dtype
649+
CategoricalDtype(categories=['a', 'b'], ordered=False)
650+
651+
*pandas 1.0.0*
652+
653+
.. ipython:: python
654+
655+
df.resample("2D").agg(lambda x: 'a').A.dtype
656+
657+
This fixes an inconsistency between ``resample`` and ``groupby``.
658+
This also fixes a potential bug, where the **values** of the result might change
659+
depending on how the results are cast back to the original dtype.
660+
661+
*pandas 0.25.x*
662+
663+
.. code-block:: python
664+
665+
>>> df.resample("2D").agg(lambda x: 'c')
666+
667+
A
668+
0 NaN
669+
670+
*pandas 1.0.0*
671+
672+
.. ipython:: python
673+
674+
df.resample("2D").agg(lambda x: 'c')
675+
676+
629677
.. _whatsnew_100.api_breaking.python:
630678

631679
Increased minimum version for Python

pandas/core/groupby/groupby.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,10 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
813813
# datetime64tz is handled correctly in agg_series,
814814
# so is excluded here.
815815

816-
# return the same type (Series) as our caller
817-
cls = dtype.construct_array_type()
818-
result = try_cast_to_ea(cls, result, dtype=dtype)
816+
if len(result) and isinstance(result[0], dtype.type):
817+
cls = dtype.construct_array_type()
818+
result = try_cast_to_ea(cls, result, dtype=dtype)
819+
819820
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
820821
result = maybe_downcast_to_dtype(result, dtype)
821822

pandas/core/groupby/ops.py

+11
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,17 @@ def _cython_operation(
543543
if mask.any():
544544
result = result.astype("float64")
545545
result[mask] = np.nan
546+
elif (
547+
how == "add"
548+
and is_integer_dtype(orig_values.dtype)
549+
and is_extension_array_dtype(orig_values.dtype)
550+
):
551+
# We need this to ensure that Series[Int64Dtype].resample().sum()
552+
# remains int64 dtype.
553+
# Two options for avoiding this special case
554+
# 1. mask-aware ops and avoid casting to float with NaN above
555+
# 2. specify the result dtype when calling this method
556+
result = result.astype("int64")
546557

547558
if kind == "aggregate" and self._filter_empty_groups and not counts.all():
548559
assert result.ndim != 2

pandas/tests/groupby/aggregate/test_aggregate.py

+37
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,43 @@ def test_lambda_named_agg(func):
648648
tm.assert_frame_equal(result, expected)
649649

650650

651+
def test_aggregate_mixed_types():
652+
# GH 16916
653+
df = pd.DataFrame(
654+
data=np.array([0] * 9).reshape(3, 3), columns=list("XYZ"), index=list("abc")
655+
)
656+
df["grouping"] = ["group 1", "group 1", 2]
657+
result = df.groupby("grouping").aggregate(lambda x: x.tolist())
658+
expected_data = [[[0], [0], [0]], [[0, 0], [0, 0], [0, 0]]]
659+
expected = pd.DataFrame(
660+
expected_data,
661+
index=Index([2, "group 1"], dtype="object", name="grouping"),
662+
columns=Index(["X", "Y", "Z"], dtype="object"),
663+
)
664+
tm.assert_frame_equal(result, expected)
665+
666+
667+
@pytest.mark.xfail(reason="Not implemented.")
668+
def test_aggregate_udf_na_extension_type():
669+
# https://github.com/pandas-dev/pandas/pull/31359
670+
# This is currently failing to cast back to Int64Dtype.
671+
# The presence of the NA causes two problems
672+
# 1. NA is not an instance of Int64Dtype.type (numpy.int64)
673+
# 2. The presence of an NA forces object type, so the non-NA values is
674+
# a Python int rather than a NumPy int64. Python ints aren't
675+
# instances of numpy.int64.
676+
def aggfunc(x):
677+
if all(x > 2):
678+
return 1
679+
else:
680+
return pd.NA
681+
682+
df = pd.DataFrame({"A": pd.array([1, 2, 3])})
683+
result = df.groupby([1, 1, 2]).agg(aggfunc)
684+
expected = pd.DataFrame({"A": pd.array([1, pd.NA], dtype="Int64")}, index=[1, 2])
685+
tm.assert_frame_equal(result, expected)
686+
687+
651688
class TestLambdaMangling:
652689
def test_maybe_mangle_lambdas_passthrough(self):
653690
assert _maybe_mangle_lambdas("mean") == "mean"

pandas/tests/groupby/test_categorical.py

+34
Original file line numberDiff line numberDiff line change
@@ -1342,3 +1342,37 @@ def test_series_groupby_categorical_aggregation_getitem():
13421342
result = groups["foo"].agg("mean")
13431343
expected = groups.agg("mean")["foo"]
13441344
tm.assert_series_equal(result, expected)
1345+
1346+
1347+
@pytest.mark.parametrize(
1348+
"func, expected_values",
1349+
[(pd.Series.nunique, [1, 1, 2]), (pd.Series.count, [1, 2, 2])],
1350+
)
1351+
def test_groupby_agg_categorical_columns(func, expected_values):
1352+
# 31256
1353+
df = pd.DataFrame(
1354+
{
1355+
"id": [0, 1, 2, 3, 4],
1356+
"groups": [0, 1, 1, 2, 2],
1357+
"value": pd.Categorical([0, 0, 0, 0, 1]),
1358+
}
1359+
).set_index("id")
1360+
result = df.groupby("groups").agg(func)
1361+
1362+
expected = pd.DataFrame(
1363+
{"value": expected_values}, index=pd.Index([0, 1, 2], name="groups"),
1364+
)
1365+
tm.assert_frame_equal(result, expected)
1366+
1367+
1368+
def test_groupby_agg_non_numeric():
1369+
df = pd.DataFrame(
1370+
{"A": pd.Categorical(["a", "a", "b"], categories=["a", "b", "c"])}
1371+
)
1372+
expected = pd.DataFrame({"A": [2, 1]}, index=[1, 2])
1373+
1374+
result = df.groupby([1, 2, 1]).agg(pd.Series.nunique)
1375+
tm.assert_frame_equal(result, expected)
1376+
1377+
result = df.groupby([1, 2, 1]).nunique()
1378+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_datetime_index.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def test_resample_integerarray():
122122

123123
result = ts.resample("3T").mean()
124124
expected = Series(
125-
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64"
125+
[1, 4, 7],
126+
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
127+
dtype="float64",
126128
)
127129
tm.assert_series_equal(result, expected)
128130

pandas/tests/resample/test_timedelta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_resample_categorical_data_with_timedeltaindex():
105105
index=pd.to_timedelta([0, 10], unit="s"),
106106
)
107107
expected = expected.reindex(["Group_obj", "Group"], axis=1)
108-
expected["Group"] = expected["Group_obj"].astype("category")
108+
expected["Group"] = expected["Group_obj"]
109109
tm.assert_frame_equal(result, expected)
110110

111111

0 commit comments

Comments
 (0)