Skip to content

Commit daec2e7

Browse files
authored
ENH: retain masked EA dtypes in groupby with as_index=False (#41373)
1 parent 4c9ef1b commit daec2e7

File tree

8 files changed

+31
-17
lines changed

8 files changed

+31
-17
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enhancement2
2929

3030
Other enhancements
3131
^^^^^^^^^^^^^^^^^^
32+
- :class:`DataFrameGroupBy` operations with ``as_index=False`` now correctly retain ``ExtensionDtype`` dtypes for columns being grouped on (:issue:`41373`)
3233
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
3334
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
3435
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)

pandas/core/groupby/generic.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
10331033
self._insert_inaxis_grouper_inplace(result)
10341034
result.index = Index(range(len(result)))
10351035

1036-
return result._convert(datetime=True)
1036+
return result
10371037

10381038
agg = aggregate
10391039

@@ -1684,6 +1684,8 @@ def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
16841684
if self.axis == 1:
16851685
result = result.T
16861686

1687+
# Note: we only need to pass datetime=True in order to get numeric
1688+
# values converted
16871689
return self._reindex_output(result)._convert(datetime=True)
16881690

16891691
def _iterate_column_groupbys(self, obj: FrameOrSeries):

pandas/core/groupby/grouper.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,20 @@ def group_arraylike(self) -> ArrayLike:
619619
Analogous to result_index, but holding an ArrayLike to ensure
620620
we can can retain ExtensionDtypes.
621621
"""
622+
if self._group_index is not None:
623+
# _group_index is set in __init__ for MultiIndex cases
624+
return self._group_index._values
625+
626+
elif self._all_grouper is not None:
627+
# retain dtype for categories, including unobserved ones
628+
return self.result_index._values
629+
622630
return self._codes_and_uniques[1]
623631

624632
@cache_readonly
625633
def result_index(self) -> Index:
626-
# TODO: what's the difference between result_index vs group_index?
634+
# result_index retains dtype for categories, including unobserved ones,
635+
# which group_index does not
627636
if self._all_grouper is not None:
628637
group_idx = self.group_index
629638
assert isinstance(group_idx, CategoricalIndex)
@@ -635,7 +644,8 @@ def group_index(self) -> Index:
635644
if self._group_index is not None:
636645
# _group_index is set in __init__ for MultiIndex cases
637646
return self._group_index
638-
uniques = self.group_arraylike
647+
648+
uniques = self._codes_and_uniques[1]
639649
return Index(uniques, name=self.name)
640650

641651
@cache_readonly

pandas/core/groupby/ops.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,7 @@ def result_arraylike(self) -> ArrayLike:
885885
if len(self.groupings) == 1:
886886
return self.groupings[0].group_arraylike
887887

888+
# result_index is MultiIndex
888889
return self.result_index._values
889890

890891
@cache_readonly
@@ -903,12 +904,12 @@ def get_group_levels(self) -> list[ArrayLike]:
903904
# Note: only called from _insert_inaxis_grouper_inplace, which
904905
# is only called for BaseGrouper, never for BinGrouper
905906
if len(self.groupings) == 1:
906-
return [self.groupings[0].result_index]
907+
return [self.groupings[0].group_arraylike]
907908

908909
name_list = []
909910
for ping, codes in zip(self.groupings, self.reconstructed_codes):
910911
codes = ensure_platform_int(codes)
911-
levels = ping.result_index.take(codes)
912+
levels = ping.group_arraylike.take(codes)
912913

913914
name_list.append(levels)
914915

pandas/tests/extension/base/groupby.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def test_grouping_grouper(self, data_for_grouping):
2222
def test_groupby_extension_agg(self, as_index, data_for_grouping):
2323
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
2424
result = df.groupby("B", as_index=as_index).A.mean()
25-
_, index = pd.factorize(data_for_grouping, sort=True)
25+
_, uniques = pd.factorize(data_for_grouping, sort=True)
2626

27-
index = pd.Index(index, name="B")
28-
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
2927
if as_index:
28+
index = pd.Index(uniques, name="B")
29+
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
3030
self.assert_series_equal(result, expected)
3131
else:
32-
expected = expected.reset_index()
32+
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
3333
self.assert_frame_equal(result, expected)
3434

3535
def test_groupby_agg_extension(self, data_for_grouping):

pandas/tests/extension/json/test_json.py

-4
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,6 @@ def test_groupby_extension_apply(self):
312312
we'll be able to dispatch unique.
313313
"""
314314

315-
@pytest.mark.parametrize("as_index", [True, False])
316-
def test_groupby_extension_agg(self, as_index, data_for_grouping):
317-
super().test_groupby_extension_agg(as_index, data_for_grouping)
318-
319315
@pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
320316
def test_groupby_agg_extension(self, data_for_grouping):
321317
super().test_groupby_agg_extension(data_for_grouping)

pandas/tests/extension/test_boolean.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,14 @@ def test_grouping_grouper(self, data_for_grouping):
269269
def test_groupby_extension_agg(self, as_index, data_for_grouping):
270270
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
271271
result = df.groupby("B", as_index=as_index).A.mean()
272-
_, index = pd.factorize(data_for_grouping, sort=True)
272+
_, uniques = pd.factorize(data_for_grouping, sort=True)
273273

274-
index = pd.Index(index, name="B")
275-
expected = pd.Series([3.0, 1.0], index=index, name="A")
276274
if as_index:
275+
index = pd.Index(uniques, name="B")
276+
expected = pd.Series([3.0, 1.0], index=index, name="A")
277277
self.assert_series_equal(result, expected)
278278
else:
279-
expected = expected.reset_index()
279+
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0]})
280280
self.assert_frame_equal(result, expected)
281281

282282
def test_groupby_agg_extension(self, data_for_grouping):

pandas/tests/groupby/test_groupby.py

+4
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,10 @@ def test_ops_not_as_index(reduction_func):
717717
expected = expected.rename("size")
718718
expected = expected.reset_index()
719719

720+
if reduction_func != "size":
721+
# 32 bit compat -> groupby preserves dtype whereas reset_index casts to int64
722+
expected["a"] = expected["a"].astype(df["a"].dtype)
723+
720724
g = df.groupby("a", as_index=False)
721725

722726
result = getattr(g, reduction_func)()

0 commit comments

Comments
 (0)