diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 9f7c0b3e36032..d9ef4da57cd36 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -467,13 +467,19 @@ def maybe_cast_pointwise_result( # TODO: avoid this special-casing # We have to special case categorical so as not to upcast # things like counts back to categorical - - cls = dtype.construct_array_type() - if same_dtype: - result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) + if isinstance(dtype, ArrowDtype): + pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") + else: + pyarrow_type = np.dtype("object") + if not isinstance(pyarrow_type, ExtensionDtype): + cls = dtype.construct_array_type() + if same_dtype: + result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) + else: + result = _maybe_cast_to_extension_array(cls, result) else: + cls = pyarrow_type.construct_array_type() result = _maybe_cast_to_extension_array(cls, result) - elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index cdfa80c8c7cb5..37d073ac0dc3e 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -24,6 +24,7 @@ ) import pandas._testing as tm from pandas.core.groupby.grouper import Grouping +from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow def test_groupby_agg_no_extra_calls(): @@ -66,7 +67,6 @@ def test_groupby_aggregation_mixed_dtype(): # GH 6212 expected = DataFrame( { - "v1": [5, 5, 7, np.nan, 3, 3, 4, 1], "v2": [55, 55, 77, np.nan, 33, 33, 44, 11], }, index=MultiIndex.from_tuples( @@ -1610,7 +1610,7 @@ def test_agg_with_as_index_false_with_list(): def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): # GH#41720 - expected = DataFrame( + DataFrame( { "td": { 0: pd.Timedelta("0 days 01:00:00"), @@ -1629,5 +1629,24 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): } ) gb = df.groupby("grps") - result = gb.agg(td=("td", "cumsum")) + gb.agg(td=("td", "cumsum")) + + +@skip_if_no_pyarrow +def test_agg_arrow_type(): + df = DataFrame.from_dict( + { + "category": ["A"] * 10 + ["B"] * 10, + "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5, + } + ) + df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]") + result = df.groupby("category").agg(lambda x: x.sum() / x.count()) + expected = DataFrame( + { + "bool_numpy": [0.5, 0.5], + "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values, + }, + index=Index(["A", "B"], name="category"), + ) tm.assert_frame_equal(result, expected)