Skip to content

Commit d979c10

Browse files
committed
BUG: Aggregation on an Arrow array returns the same type.
Signed-off-by: Liang Yan <[email protected]>
1 parent f0d3301 commit d979c10

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

pandas/core/dtypes/cast.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -458,17 +458,26 @@ def maybe_cast_pointwise_result(
458458
"""
459459

460460
if isinstance(dtype, ExtensionDtype):
461-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
461+
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
462462
# TODO: avoid this special-casing
463463
# We have to special case categorical so as not to upcast
464464
# things like counts back to categorical
465-
466465
cls = dtype.construct_array_type()
467466
if same_dtype:
468467
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
469468
else:
470469
result = _maybe_cast_to_extension_array(cls, result)
471-
470+
elif isinstance(dtype, ArrowDtype):
471+
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
472+
if isinstance(pyarrow_type, ExtensionDtype):
473+
cls = pyarrow_type.construct_array_type()
474+
result = _maybe_cast_to_extension_array(cls, result)
475+
else:
476+
cls = dtype.construct_array_type()
477+
if same_dtype:
478+
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
479+
else:
480+
result = _maybe_cast_to_extension_array(cls, result)
472481
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
473482
result = maybe_downcast_to_dtype(result, dtype)
474483

pandas/tests/groupby/aggregate/test_aggregate.py

+23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
from functools import partial
77
import re
8+
import typing
89

910
import numpy as np
1011
import pytest
@@ -1603,3 +1604,25 @@ def test_agg_with_as_index_false_with_list():
16031604
columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]),
16041605
)
16051606
tm.assert_frame_equal(result, expected)
1607+
1608+
1609+
def test_agg_arrow_type():
    """Check the result dtype of a UDF aggregation on an Arrow-backed column.

    A boolean column is stored twice, once as a NumPy-backed column and once
    as ``bool[pyarrow]``. Both are grouped by ``"category"`` and aggregated
    with ``sum() / count()``, which yields the fraction of ``True`` values
    (0.5 in every group). The Arrow-backed column is expected to come back
    as ``double[pyarrow]`` instead of being cast back to its original
    boolean dtype.
    """
    # Skip only when pyarrow is actually missing. A guard based on
    # ``typing.TYPE_CHECKING`` would skip unconditionally, because that
    # constant is always False at runtime.
    pytest.importorskip("pyarrow")

    df = DataFrame.from_dict(
        {
            "category": ["A"] * 10 + ["B"] * 10,
            "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
        }
    )
    df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")

    result = df.groupby("category").agg(lambda x: x.sum() / x.count())

    expected = DataFrame(
        {
            "bool_numpy": [0.5, 0.5],
            "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values,
        },
        index=Index(["A", "B"], name="category"),
    )
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)