Skip to content

Commit 04180b0

Browse files
authored
BUG: groupby.count maintains masked and arrow dtypes (#54129)
1 parent 4c67076 commit 04180b0

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

doc/source/whatsnew/v2.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ Groupby/resample/rolling
576576
- Bug in :meth:`GroupBy.var` failing to raise ``TypeError`` when called with datetime64, timedelta64 or :class:`PeriodDtype` values (:issue:`52128`, :issue:`53045`)
577577
- Bug in :meth:`DataFrameGroupBy.resample` with ``kind="period"`` raising ``AttributeError`` (:issue:`24103`)
578578
- Bug in :meth:`Resampler.ohlc` with empty object returning a :class:`Series` instead of empty :class:`DataFrame` (:issue:`42902`)
579+
- Bug in :meth:`SeriesGroupBy.count` and :meth:`DataFrameGroupBy.count` where the result dtype would be ``np.int64`` instead of a masked or Arrow-backed integer dtype for data with :class:`ArrowDtype` or masked dtypes (e.g. ``Int64``) (:issue:`53831`)
579580
- Bug in :meth:`SeriesGroupBy.nth` and :meth:`DataFrameGroupBy.nth` after performing column selection when using ``dropna="any"`` or ``dropna="all"`` would not subset columns (:issue:`53518`)
580581
- Bug in :meth:`SeriesGroupBy.nth` and :meth:`DataFrameGroupBy.nth` raised after performing column selection when using ``dropna="any"`` or ``dropna="all"`` resulted in rows being dropped (:issue:`53518`)
581582
- Bug in :meth:`SeriesGroupBy.sum` and :meth:`DataFrameGroupBy.sum` summing ``np.inf + np.inf`` and ``(-np.inf) + (-np.inf)`` to ``np.nan`` (:issue:`53606`)

pandas/core/groupby/groupby.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class providing the base-class of operations.
104104
Categorical,
105105
ExtensionArray,
106106
FloatingArray,
107+
IntegerArray,
107108
)
108109
from pandas.core.base import (
109110
PandasObject,
@@ -2248,6 +2249,12 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
22482249
masked = mask & ~isna(bvalues)
22492250

22502251
counted = lib.count_level_2d(masked, labels=ids, max_bin=ngroups)
2252+
if isinstance(bvalues, BaseMaskedArray):
2253+
return IntegerArray(
2254+
counted[0], mask=np.zeros(counted.shape[1], dtype=np.bool_)
2255+
)
2256+
elif isinstance(bvalues, ArrowExtensionArray):
2257+
return type(bvalues)._from_sequence(counted[0])
22512258
if is_series:
22522259
assert counted.ndim == 2
22532260
assert counted.shape[0] == 1

pandas/tests/extension/test_arrow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3154,6 +3154,18 @@ def test_to_numpy_temporal(pa_type):
31543154
tm.assert_numpy_array_equal(result, expected)
31553155

31563156

3157+
def test_groupby_count_return_arrow_dtype(data_missing):
3158+
df = pd.DataFrame({"A": [1, 1], "B": data_missing, "C": data_missing})
3159+
result = df.groupby("A").count()
3160+
expected = pd.DataFrame(
3161+
[[1, 1]],
3162+
index=pd.Index([1], name="A"),
3163+
columns=["B", "C"],
3164+
dtype="int64[pyarrow]",
3165+
)
3166+
tm.assert_frame_equal(result, expected)
3167+
3168+
31573169
def test_arrowextensiondtype_dataframe_repr():
31583170
# GH 54062
31593171
df = pd.DataFrame(

pandas/tests/groupby/aggregate/test_cython.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,27 @@ def test_cython_agg_nullable_int(op_name):
343343
# the result is not yet consistently using Int64/Float64 dtype,
344344
# so for now just checking the values by casting to float
345345
result = result.astype("float64")
346+
else:
347+
result = result.astype("int64")
346348
tm.assert_series_equal(result, expected)
347349

348350

351+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
352+
def test_count_masked_returns_masked_dtype(dtype):
353+
df = DataFrame(
354+
{
355+
"A": [1, 1],
356+
"B": pd.array([1, pd.NA], dtype=dtype),
357+
"C": pd.array([1, 1], dtype=dtype),
358+
}
359+
)
360+
result = df.groupby("A").count()
361+
expected = DataFrame(
362+
[[1, 2]], index=Index([1], name="A"), columns=["B", "C"], dtype="Int64"
363+
)
364+
tm.assert_frame_equal(result, expected)
365+
366+
349367
@pytest.mark.parametrize("with_na", [True, False])
350368
@pytest.mark.parametrize(
351369
"op_name, action",

0 commit comments

Comments
 (0)