Skip to content

Commit fd9a9ea

Browse files
authored
BUG: is_string_dtype returns True for ArrowDtype(pa.string()) (#50963)
1 parent dba96f9 commit fd9a9ea

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

doc/source/whatsnew/v2.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ Conversion
10211021

10221022
Strings
10231023
^^^^^^^
1024-
- Bug in :func:`pandas.api.dtypes.is_string_dtype` that would not return ``True`` for :class:`StringDtype` (:issue:`15585`)
1024+
- Bug in :func:`pandas.api.types.is_string_dtype` that would not return ``True`` for :class:`StringDtype` or :class:`ArrowDtype` with ``pyarrow.string()`` (:issue:`15585`)
10251025
- Bug in converting string dtypes to "datetime64[ns]" or "timedelta64[ns]" incorrectly raising ``TypeError`` (:issue:`36153`)
10261026
-
10271027

pandas/core/arrays/arrow/dtype.py

+3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def name(self) -> str: # type: ignore[override]
9595
@cache_readonly
9696
def numpy_dtype(self) -> np.dtype:
9797
"""Return an instance of the related numpy dtype"""
98+
if pa.types.is_string(self.pyarrow_dtype):
99+
# pa.string().to_pandas_dtype() returns object, which we don't want here
100+
return np.dtype(str)
98101
try:
99102
return np.dtype(self.pyarrow_dtype.to_pandas_dtype())
100103
except (NotImplementedError, TypeError):

pandas/tests/extension/test_arrow.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
is_integer_dtype,
4747
is_numeric_dtype,
4848
is_signed_integer_dtype,
49+
is_string_dtype,
4950
is_unsigned_integer_dtype,
5051
)
5152
from pandas.tests.extension import base
@@ -651,6 +652,24 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
651652
):
652653
super().test_groupby_extension_agg(as_index, data_for_grouping)
653654

655+
def test_in_numeric_groupby(self, data_for_grouping):
656+
if is_string_dtype(data_for_grouping.dtype):
657+
df = pd.DataFrame(
658+
{
659+
"A": [1, 1, 2, 2, 3, 3, 1, 4],
660+
"B": data_for_grouping,
661+
"C": [1, 1, 1, 1, 1, 1, 1, 1],
662+
}
663+
)
664+
665+
expected = pd.Index(["C"])
666+
with pytest.raises(TypeError, match="does not support"):
667+
df.groupby("A").sum().columns
668+
result = df.groupby("A").sum(numeric_only=True).columns
669+
tm.assert_index_equal(result, expected)
670+
else:
671+
super().test_in_numeric_groupby(data_for_grouping)
672+
654673

655674
class TestBaseDtype(base.BaseDtypeTests):
656675
def test_construct_from_string_own_name(self, dtype, request):
@@ -730,7 +749,6 @@ def test_get_common_dtype(self, dtype, request):
730749
and (pa_dtype.unit != "ns" or pa_dtype.tz is not None)
731750
)
732751
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
733-
or pa.types.is_string(pa_dtype)
734752
or pa.types.is_binary(pa_dtype)
735753
):
736754
request.node.add_marker(
@@ -743,6 +761,13 @@ def test_get_common_dtype(self, dtype, request):
743761
)
744762
super().test_get_common_dtype(dtype)
745763

764+
def test_is_not_string_type(self, dtype):
765+
pa_dtype = dtype.pyarrow_dtype
766+
if pa.types.is_string(pa_dtype):
767+
assert is_string_dtype(dtype)
768+
else:
769+
super().test_is_not_string_type(dtype)
770+
746771

747772
class TestBaseIndex(base.BaseIndexTests):
748773
pass

0 commit comments

Comments
 (0)