Skip to content

Commit a5fe7ac

Browse files
committed
Fix convert_dtypes issue in maybe_cast_pointwise_result.
Signed-off-by: Liang Yan <[email protected]>
1 parent d8843e5 commit a5fe7ac

File tree

2 files changed

+9
-33
lines changed

2 files changed

+9
-33
lines changed

pandas/core/dtypes/cast.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -463,27 +463,23 @@ def maybe_cast_pointwise_result(
463463
"""
464464

465465
if isinstance(dtype, ExtensionDtype):
466-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
466+
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
467467
# TODO: avoid this special-casing
468468
# We have to special case categorical so as not to upcast
469469
# things like counts back to categorical
470-
471-
cls = dtype.construct_array_type()
472-
if same_dtype:
473-
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
474-
else:
475-
result = _maybe_cast_to_extension_array(cls, result)
476-
elif isinstance(dtype, ArrowDtype):
477-
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
478-
if isinstance(pyarrow_type, ExtensionDtype):
479-
cls = pyarrow_type.construct_array_type()
480-
result = _maybe_cast_to_extension_array(cls, result)
470+
if isinstance(dtype, ArrowDtype):
471+
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
481472
else:
473+
pyarrow_type = np.dtype("object")
474+
if not isinstance(pyarrow_type, ExtensionDtype):
482475
cls = dtype.construct_array_type()
483476
if same_dtype:
484477
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
485478
else:
486479
result = _maybe_cast_to_extension_array(cls, result)
480+
else:
481+
cls = pyarrow_type.construct_array_type()
482+
result = _maybe_cast_to_extension_array(cls, result)
487483
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
488484
result = maybe_downcast_to_dtype(result, dtype)
489485

pandas/tests/groupby/aggregate/test_aggregate.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import functools
66
from functools import partial
77
import re
8-
import typing
98

109
import numpy as np
1110
import pytest
@@ -25,6 +24,7 @@
2524
)
2625
import pandas._testing as tm
2726
from pandas.core.groupby.grouper import Grouping
27+
from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow
2828

2929

3030
def test_groupby_agg_no_extra_calls():
@@ -1631,23 +1631,3 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation():
16311631
)
16321632
gb = df.groupby("grps")
16331633
result = gb.agg(td=("td", "cumsum"))
1634-
1635-
1636-
@skip_if_no_pyarrow
1637-
def test_agg_arrow_type():
1638-
df = DataFrame.from_dict(
1639-
{
1640-
"category": ["A"] * 10 + ["B"] * 10,
1641-
"bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5,
1642-
}
1643-
)
1644-
df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]")
1645-
result = df.groupby("category").agg(lambda x: x.sum() / x.count())
1646-
expected = DataFrame(
1647-
{
1648-
"bool_numpy": [0.5, 0.5],
1649-
"bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values,
1650-
},
1651-
index=Index(["A", "B"], name="category"),
1652-
)
1653-
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)