diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index a2c3bf1453c2e..f4dbf9e08ac18 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -559,6 +559,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrameGroupBy.apply` raising a ``TypeError`` when selecting multiple columns and providing a function that returns ``np.ndarray`` results (:issue:`18930`) - Bug in :meth:`GroupBy.groups` with a datetime key in conjunction with another key produced incorrect number of group keys (:issue:`51158`) - Bug in :meth:`GroupBy.quantile` may implicitly sort the result index with ``sort=False`` (:issue:`53009`) +- Bug in :meth:`SeriesGroupBy.size` where the dtype would be ``np.int64`` for data with :class:`ArrowDtype` or masked dtypes (e.g. ``Int64``) (:issue:`53831`) - Bug in :meth:`GroupBy.var` failing to raise ``TypeError`` when called with datetime64, timedelta64 or :class:`PeriodDtype` values (:issue:`52128`, :issue:`53045`) - Bug in :meth:`DataFrameGroupby.resample` with ``kind="period"`` raising ``AttributeError`` (:issue:`24103`) - Bug in :meth:`Resampler.ohlc` with empty object returning a :class:`Series` instead of empty :class:`DataFrame` (:issue:`42902`) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 9d87a28093371..64d874a31c428 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -99,6 +99,7 @@ class providing the base-class of operations. from pandas.core._numba import executor from pandas.core.apply import warn_alias_replacement from pandas.core.arrays import ( + ArrowExtensionArray, BaseMaskedArray, Categorical, ExtensionArray, @@ -2930,6 +2931,13 @@ def size(self) -> DataFrame | Series: Freq: MS, dtype: int64 """ result = self.grouper.size() + dtype_backend: None | Literal["pyarrow", "numpy_nullable"] = None + if isinstance(self.obj, Series): + if isinstance(self.obj.array, ArrowExtensionArray): + dtype_backend = "pyarrow" + elif isinstance(self.obj.array, BaseMaskedArray): + dtype_backend = "numpy_nullable" + # TODO: For DataFrames what if columns are mixed arrow/numpy/masked? # GH28330 preserve subclassed Series/DataFrames through calls if isinstance(self.obj, Series): @@ -2937,6 +2945,15 @@ def size(self) -> DataFrame | Series: else: result = self._obj_1d_constructor(result) + if dtype_backend is not None: + result = result.convert_dtypes( + infer_objects=False, + convert_string=False, + convert_boolean=False, + convert_floating=False, + dtype_backend=dtype_backend, + ) + with com.temp_setattr(self, "as_index", True): # size already has the desired behavior in GH#49519, but this makes the # as_index=False path of _reindex_output fail on categorical groupers. diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 2a6b57a365a11..b6fa4fbcaa92c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3122,6 +3122,14 @@ def test_iter_temporal(pa_type): assert result == expected +def test_groupby_series_size_returns_pa_int(data): + # GH 54132 + ser = pd.Series(data[:3], index=["a", "a", "b"]) + result = ser.groupby(level=0).size() + expected = pd.Series([2, 1], dtype="int64[pyarrow]", index=["a", "b"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES ) diff --git a/pandas/tests/groupby/test_size.py b/pandas/tests/groupby/test_size.py index b96fe41c26c3e..e7598ec34fa15 100644 --- a/pandas/tests/groupby/test_size.py +++ b/pandas/tests/groupby/test_size.py @@ -95,3 +95,12 @@ def test_size_on_categorical(as_index): expected = expected.set_index(["A", "B"])["size"].rename(None) tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_size_series_masked_type_returns_Int64(dtype): + # GH 54132 + ser = Series([1, 1, 1], index=["a", "a", "b"], dtype=dtype) + result = ser.groupby(level=0).size() + expected = Series([2, 1], dtype="Int64", index=["a", "b"]) + tm.assert_series_equal(result, expected)