Skip to content

Commit 6028670

Browse files
committed
BUG: Aggregation on arrow array return same type.
Signed-off-by: Liang Yan <[email protected]>
1 parent d06f2d3 commit 6028670

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

pandas/core/groupby/ops.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
if TYPE_CHECKING:
7474
from pandas.core.generic import NDFrame
7575

76+
from pandas.arrays import ArrowExtensionArray
77+
7678

7779
def check_result_array(obj, dtype):
7880
# Our operation is supposed to be an aggregation/reduction. If
@@ -837,7 +839,11 @@ def agg_series(
837839
# test_groupby_empty_with_category gets here with self.ngroups == 0
838840
# and len(obj) > 0
839841

840-
if len(obj) > 0 and not isinstance(obj._values, np.ndarray):
842+
if (
843+
len(obj) > 0
844+
and not isinstance(obj._values, np.ndarray)
845+
and not isinstance(obj._values, ArrowExtensionArray)
846+
):
841847
# we can preserve a little bit more aggressively with EA dtype
842848
# because maybe_cast_pointwise_result will do a try/except
843849
# with _from_sequence. NB we are assuming here that _from_sequence

pandas/tests/groupby/aggregate/test_aggregate.py

+21
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
from functools import partial
77
import re
8+
import typing
89

910
import numpy as np
1011
import pytest
@@ -1603,3 +1604,23 @@ def test_agg_with_as_index_false_with_list():
16031604
columns=MultiIndex.from_tuples([("a1", ""), ("a2", ""), ("b", "sum")]),
16041605
)
16051606
tm.assert_frame_equal(result, expected)
1607+
1608+
1609+
def test_agg_arrow_type():
    """Check a UDF aggregation over a pyarrow-backed boolean column.

    Builds a frame with a NumPy boolean column and the same data cast to
    ``bool[pyarrow]``, groups by ``category`` and aggregates each column with
    ``sum / count``.  Both columns are expected to produce the same
    floating-point ratios (0.5 for each group) instead of the pyarrow column
    being cast back to its original boolean dtype.
    """
    # ``typing.TYPE_CHECKING`` is always False at runtime, so a
    # ``skipif(not typing.TYPE_CHECKING)`` marker skips this test on every
    # run.  Skip only when pyarrow really is not installed.
    pytest.importorskip("pyarrow")

    df = DataFrame.from_dict(
        {
            "category": ["A"] * 10 + ["B"] * 10,
            "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
        }
    )
    # Same values as "bool_numpy", but backed by a pyarrow boolean array.
    df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")

    result = df.groupby("category").agg(lambda x: x.sum() / x.count())

    # Half of the values in every group are True, hence a ratio of 0.5.
    expected = DataFrame(
        [[0.5, 0.5], [0.5, 0.5]],
        columns=["bool_numpy", "bool_arrow"],
        index=Index(["A", "B"], name="category"),
    )
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)