Skip to content

Commit d1a8378

Browse files
committed
Fix convert_dtype issue.
Signed-off-by: Liang Yan <[email protected]>
1 parent 96fd5de commit d1a8378

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

pandas/core/dtypes/cast.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -458,27 +458,23 @@ def maybe_cast_pointwise_result(
458458
"""
459459

460460
if isinstance(dtype, ExtensionDtype):
461-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)):
461+
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
462462
# TODO: avoid this special-casing
463463
# We have to special case categorical so as not to upcast
464464
# things like counts back to categorical
465-
466-
cls = dtype.construct_array_type()
467-
if same_dtype:
468-
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
469-
else:
470-
result = _maybe_cast_to_extension_array(cls, result)
471-
elif isinstance(dtype, ArrowDtype):
472-
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
473-
if isinstance(pyarrow_type, ExtensionDtype):
474-
cls = pyarrow_type.construct_array_type()
475-
result = _maybe_cast_to_extension_array(cls, result)
465+
if isinstance(dtype, ArrowDtype):
466+
pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow")
476467
else:
468+
pyarrow_type = np.dtype("object")
469+
if not isinstance(pyarrow_type, ExtensionDtype):
477470
cls = dtype.construct_array_type()
478471
if same_dtype:
479472
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
480473
else:
481474
result = _maybe_cast_to_extension_array(cls, result)
475+
else:
476+
cls = pyarrow_type.construct_array_type()
477+
result = _maybe_cast_to_extension_array(cls, result)
482478
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
483479
result = maybe_downcast_to_dtype(result, dtype)
484480

pandas/tests/groupby/aggregate/test_aggregate.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import functools
66
from functools import partial
77
import re
8-
import typing
98

109
import numpy as np
1110
import pytest
@@ -25,6 +24,7 @@
2524
)
2625
import pandas._testing as tm
2726
from pandas.core.groupby.grouper import Grouping
27+
from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow
2828

2929

3030
def test_groupby_agg_no_extra_calls():
@@ -1610,9 +1610,7 @@ def test_agg_with_as_index_false_with_list():
16101610
tm.assert_frame_equal(result, expected)
16111611

16121612

1613-
# @pytest.mark.skipif(
1614-
# not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py"
1615-
# )
1613+
@skip_if_no_pyarrow
16161614
def test_agg_arrow_type():
16171615
df = DataFrame.from_dict(
16181616
{

0 commit comments

Comments
 (0)