Skip to content

Commit 3ba1a4b

Browse files
committed
Fix convert_dtype issue.
Signed-off-by: Liang Yan <[email protected]>
1 parent 0a18797 commit 3ba1a4b

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

pandas/core/dtypes/cast.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -461,27 +461,23 @@ def maybe_cast_pointwise_result(
461461
"""
462462

463463
if isinstance(dtype, ExtensionDtype):
464-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
464+
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
465465
# TODO: avoid this special-casing
466466
# We have to special case categorical so as not to upcast
467467
# things like counts back to categorical
468-
469-
cls = dtype.construct_array_type()
470-
if same_dtype:
471-
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
472-
else:
473-
result = _maybe_cast_to_extension_array(cls, result)
474-
elif isinstance(dtype, ArrowDtype):
475-
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
476-
if isinstance(pyarrow_type, ExtensionDtype):
477-
cls = pyarrow_type.construct_array_type()
478-
result = _maybe_cast_to_extension_array(cls, result)
468+
if isinstance(dtype, ArrowDtype):
469+
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
479470
else:
471+
pyarrow_type = np.dtype("object")
472+
if not isinstance(pyarrow_type, ExtensionDtype):
480473
cls = dtype.construct_array_type()
481474
if same_dtype:
482475
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
483476
else:
484477
result = _maybe_cast_to_extension_array(cls, result)
478+
else:
479+
cls = pyarrow_type.construct_array_type()
480+
result = _maybe_cast_to_extension_array(cls, result)
485481
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
486482
result = maybe_downcast_to_dtype(result, dtype)
487483

pandas/tests/groupby/aggregate/test_aggregate.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import functools
66
from functools import partial
77
import re
8-
import typing
98

109
import numpy as np
1110
import pytest
@@ -25,6 +24,7 @@
2524
)
2625
import pandas._testing as tm
2726
from pandas.core.groupby.grouper import Grouping
27+
from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow
2828

2929

3030
def test_groupby_agg_no_extra_calls():
@@ -1610,9 +1610,7 @@ def test_agg_with_as_index_false_with_list():
16101610
tm.assert_frame_equal(result, expected)
16111611

16121612

1613-
# @pytest.mark.skipif(
1614-
# not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py"
1615-
# )
1613+
@skip_if_no_pyarrow
16161614
def test_agg_arrow_type():
16171615
df = DataFrame.from_dict(
16181616
{

0 commit comments

Comments
 (0)