Skip to content

Commit 05a7672

Browse files
jbrockmendelyeshsurya
authored andcommitted
BUG: retain ordered Categorical dtype in SeriesGroupBy aggregations (pandas-dev#41147)
1 parent 9bee5a2 commit 05a7672

File tree

4 files changed

+64
-13
lines changed

4 files changed

+64
-13
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ Groupby/resample/rolling
852852
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` computing wrong result with nullable data types too large to roundtrip when casting to float (:issue:`37493`)
853853
- Bug in :meth:`DataFrame.rolling` returning mean zero for all ``NaN`` window with ``min_periods=0`` if calculation is not numerical stable (:issue:`41053`)
854854
- Bug in :meth:`DataFrame.rolling` returning sum not zero for all ``NaN`` window with ``min_periods=0`` if calculation is not numerical stable (:issue:`41053`)
855+
- Bug in :meth:`SeriesGroupBy.agg` failing to retain ordered :class:`CategoricalDtype` on order-preserving aggregations (:issue:`41147`)
855856
- Bug in :meth:`DataFrameGroupBy.min` and :meth:`DataFrameGroupBy.max` with multiple object-dtype columns and ``numeric_only=False`` incorrectly raising ``ValueError`` (:issue:41111`)
856857

857858
Reshaping

pandas/core/groupby/generic.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,26 @@ def _cython_agg_general(
358358
if numeric_only and not is_numeric:
359359
continue
360360

361-
result = self.grouper._cython_operation(
362-
"aggregate", obj._values, how, axis=0, min_count=min_count
363-
)
361+
objvals = obj._values
362+
363+
if isinstance(objvals, Categorical):
364+
if self.grouper.ngroups > 0:
365+
# without special-casing, we would raise, then in fallback
366+
# would eventually call agg_series but without re-casting
367+
# to Categorical
368+
# equiv: res_values, _ = self.grouper.agg_series(obj, alt)
369+
res_values, _ = self.grouper._aggregate_series_pure_python(obj, alt)
370+
else:
371+
# equiv: res_values = self._python_agg_general(alt)
372+
res_values = self._python_apply_general(alt, self._selected_obj)
373+
374+
result = type(objvals)._from_sequence(res_values, dtype=objvals.dtype)
375+
376+
else:
377+
result = self.grouper._cython_operation(
378+
"aggregate", obj._values, how, axis=0, min_count=min_count
379+
)
380+
364381
assert result.ndim == 1
365382
key = base.OutputKey(label=name, position=idx)
366383
output[key] = result
@@ -1092,13 +1109,7 @@ def cast_agg_result(result: ArrayLike, values: ArrayLike) -> ArrayLike:
10921109
# see if we can cast the values to the desired dtype
10931110
# this may not be the original dtype
10941111

1095-
if isinstance(values, Categorical) and isinstance(result, np.ndarray):
1096-
# If the Categorical op didn't raise, it is dtype-preserving
1097-
# We get here with how="first", "last", "min", "max"
1098-
result = type(values)._from_sequence(result.ravel(), dtype=values.dtype)
1099-
# Note this will have result.dtype == dtype from above
1100-
1101-
elif (
1112+
if (
11021113
not using_array_manager
11031114
and isinstance(result.dtype, np.dtype)
11041115
and result.ndim == 1
@@ -1140,9 +1151,14 @@ def py_fallback(values: ArrayLike) -> ArrayLike:
11401151
# Categoricals. This will done by later self._reindex_output()
11411152
# Doing it here creates an error. See GH#34951
11421153
sgb = get_groupby(obj, self.grouper, observed=True)
1154+
11431155
# Note: bc obj is always a Series here, we can ignore axis and pass
11441156
# `alt` directly instead of `lambda x: alt(x, axis=self.axis)`
1145-
res_ser = sgb.aggregate(alt) # this will go through sgb._python_agg_general
1157+
# use _agg_general bc it will go through _cython_agg_general
1158+
# which will correctly cast Categoricals.
1159+
res_ser = sgb._agg_general(
1160+
numeric_only=False, min_count=min_count, alias=how, npfunc=alt
1161+
)
11461162

11471163
# unwrap Series to get array
11481164
res_values = res_ser._mgr.arrays[0]

pandas/tests/groupby/aggregate/test_aggregate.py

+9
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,11 @@ def test_groupby_single_agg_cat_cols(grp_col_dict, exp_data):
11051105

11061106
expected_df = DataFrame(data=exp_data, index=cat_index)
11071107

1108+
if "cat_ord" in expected_df:
1109+
# ordered categorical columns should be preserved
1110+
dtype = input_df["cat_ord"].dtype
1111+
expected_df["cat_ord"] = expected_df["cat_ord"].astype(dtype)
1112+
11081113
tm.assert_frame_equal(result_df, expected_df)
11091114

11101115

@@ -1149,6 +1154,10 @@ def test_groupby_combined_aggs_cat_cols(grp_col_dict, exp_data):
11491154
multi_index = MultiIndex.from_tuples(tuple(multi_index_list))
11501155

11511156
expected_df = DataFrame(data=exp_data, columns=multi_index, index=cat_index)
1157+
for col in expected_df.columns:
1158+
if isinstance(col, tuple) and "cat_ord" in col:
1159+
# ordered categorical should be preserved
1160+
expected_df[col] = expected_df[col].astype(input_df["cat_ord"].dtype)
11521161

11531162
tm.assert_frame_equal(result_df, expected_df)
11541163

pandas/tests/groupby/test_categorical.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,12 @@ def test_preserve_on_ordered_ops(func, values):
800800
).set_index("payload")
801801
tm.assert_frame_equal(result, expected)
802802

803+
# we should also preserve categorical for SeriesGroupBy
804+
sgb = df.groupby("payload")["col"]
805+
result = getattr(sgb, func)()
806+
expected = expected["col"]
807+
tm.assert_series_equal(result, expected)
808+
803809

804810
def test_categorical_no_compress():
805811
data = Series(np.random.randn(9))
@@ -1494,7 +1500,11 @@ def test_groupy_first_returned_categorical_instead_of_dataframe(func):
14941500
df = DataFrame({"A": [1997], "B": Series(["b"], dtype="category").cat.as_ordered()})
14951501
df_grouped = df.groupby("A")["B"]
14961502
result = getattr(df_grouped, func)()
1497-
expected = Series(["b"], index=Index([1997], name="A"), name="B")
1503+
1504+
# ordered categorical dtype should be preserved
1505+
expected = Series(
1506+
["b"], index=Index([1997], name="A"), name="B", dtype=df["B"].dtype
1507+
)
14981508
tm.assert_series_equal(result, expected)
14991509

15001510

@@ -1561,7 +1571,15 @@ def test_agg_cython_category_not_implemented_fallback():
15611571
df["col_cat"] = df["col_num"].astype("category")
15621572

15631573
result = df.groupby("col_num").col_cat.first()
1564-
expected = Series([1, 2, 3], index=Index([1, 2, 3], name="col_num"), name="col_cat")
1574+
1575+
# ordered categorical dtype should definitely be preserved;
1576+
# this is unordered, so is less-clear case (if anything, it should raise)
1577+
expected = Series(
1578+
[1, 2, 3],
1579+
index=Index([1, 2, 3], name="col_num"),
1580+
name="col_cat",
1581+
dtype=df["col_cat"].dtype,
1582+
)
15651583
tm.assert_series_equal(result, expected)
15661584

15671585
result = df.groupby("col_num").agg({"col_cat": "first"})
@@ -1576,6 +1594,10 @@ def test_aggregate_categorical_lost_index(func: str):
15761594
df = DataFrame({"A": [1997], "B": ds})
15771595
result = df.groupby("A").agg({"B": func})
15781596
expected = DataFrame({"B": ["b"]}, index=Index([1997], name="A"))
1597+
1598+
# ordered categorical dtype should be preserved
1599+
expected["B"] = expected["B"].astype(ds.dtype)
1600+
15791601
tm.assert_frame_equal(result, expected)
15801602

15811603

@@ -1653,6 +1675,9 @@ def test_categorical_transform():
16531675

16541676
expected["status"] = expected["status"].astype(delivery_status_type)
16551677

1678+
# .transform(max) should preserve ordered categoricals
1679+
expected["last_status"] = expected["last_status"].astype(delivery_status_type)
1680+
16561681
tm.assert_frame_equal(result, expected)
16571682

16581683

0 commit comments

Comments
 (0)