Skip to content

Commit 2451d2e

Browse files
Backport PR #54752 on branch 2.1.x (REGR: groupby.count returning string dtype instead of numeric for string input) (#54762)
Backport PR #54752: REGR: groupby.count returning string dtype instead of numeric for string input

Co-authored-by: Patrick Hoefler <[email protected]>
1 parent b761844 commit 2451d2e

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

pandas/core/groupby/groupby.py

+4 -1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class providing the base-class of operations.
107107
IntegerArray,
108108
SparseArray,
109109
)
110+
from pandas.core.arrays.string_ import StringDtype
110111
from pandas.core.base import (
111112
PandasObject,
112113
SelectionMixin,
@@ -2261,7 +2262,9 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
22612262
return IntegerArray(
22622263
counted[0], mask=np.zeros(counted.shape[1], dtype=np.bool_)
22632264
)
2264-
elif isinstance(bvalues, ArrowExtensionArray):
2265+
elif isinstance(bvalues, ArrowExtensionArray) and not isinstance(
2266+
bvalues.dtype, StringDtype
2267+
):
22652268
return type(bvalues)._from_sequence(counted[0])
22662269
if is_series:
22672270
assert counted.ndim == 2

pandas/tests/groupby/test_counting.py

+11
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,14 @@ def __eq__(self, other):
379379
result = df.groupby("grp").count()
380380
expected = DataFrame({"a": [2, 2]}, index=Index(list("ab"), name="grp"))
381381
tm.assert_frame_equal(result, expected)
382+
383+
384+
def test_count_arrow_string_array(any_string_dtype):
385+
# GH#54751
386+
pytest.importorskip("pyarrow")
387+
df = DataFrame(
388+
{"a": [1, 2, 3], "b": Series(["a", "b", "a"], dtype=any_string_dtype)}
389+
)
390+
result = df.groupby("a").count()
391+
expected = DataFrame({"b": 1}, index=Index([1, 2, 3], name="a"))
392+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)