Skip to content

Commit a9e30c5

Browse files
String dtype: map builtin str alias to StringDtype (#59685)
* String dtype: map builtin str alias to StringDtype
* fix tests
* fix datetimelike astype and more tests
* remove xfails
* try fix typing
* fix copy_view tests
* fix remaining tests with infer_string enabled
* ignore typing issue for now
* move to common.py
* simplify Categorical._str_get_dummies
* small cleanup
* fix ensure_string_array to not modify extension arrays inplace
* fix ensure_string_array once more + fix is_extension_array_dtype for str
* still xfail TestArrowArray::test_astype_str when not using infer_string
* ensure maybe_convert_objects copies object dtype input array when inferring StringDtype
* update test_1d_object_array_does_not_copy test
* update constructor copy test + do not copy in maybe_convert_objects?
* skip str.get_dummies test for now
* use pandas_dtype() instead of registry.find
* fix corner cases for calling pandas_dtype
* add TODO comment in ensure_string_array
1 parent 5b6997c commit a9e30c5

32 files changed

+185
-111
lines changed

pandas/_libs/lib.pyx

+8-1
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,14 @@ cpdef ndarray[object] ensure_string_array(
754754

755755
if hasattr(arr, "to_numpy"):
756756

757-
if hasattr(arr, "dtype") and arr.dtype.kind in "mM":
757+
if (
758+
hasattr(arr, "dtype")
759+
and arr.dtype.kind in "mM"
760+
# TODO: we should add a custom ArrowExtensionArray.astype implementation
761+
# that handles astype(str) specifically, avoiding ending up here and
762+
# then we can remove the below check for `_pa_array` (for ArrowEA)
763+
and not hasattr(arr, "_pa_array")
764+
):
758765
# dtype check to exclude DataFrame
759766
# GH#41409 TODO: not a great place for this
760767
out = arr.astype(str).astype(object)

pandas/_testing/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108

109109
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
110110
if using_string_dtype():
111-
STRING_DTYPES: list[Dtype] = [str, "U"]
111+
STRING_DTYPES: list[Dtype] = ["U"]
112112
else:
113113
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
114114
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

pandas/core/arrays/categorical.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2685,7 +2685,9 @@ def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
26852685
# sep may not be in categories. Just bail on this.
26862686
from pandas.core.arrays import NumpyExtensionArray
26872687

2688-
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)
2688+
return NumpyExtensionArray(self.to_numpy(str, na_value="NaN"))._str_get_dummies(
2689+
sep, dtype
2690+
)
26892691

26902692
# ------------------------------------------------------------------------
26912693
# GroupBy Methods

pandas/core/arrays/datetimelike.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,16 @@ def astype(self, dtype, copy: bool = True):
471471

472472
return self._box_values(self.asi8.ravel()).reshape(self.shape)
473473

474+
elif is_string_dtype(dtype):
475+
if isinstance(dtype, ExtensionDtype):
476+
arr_object = self._format_native_types(na_rep=dtype.na_value) # type: ignore[arg-type]
477+
cls = dtype.construct_array_type()
478+
return cls._from_sequence(arr_object, dtype=dtype, copy=False)
479+
else:
480+
return self._format_native_types()
481+
474482
elif isinstance(dtype, ExtensionDtype):
475483
return super().astype(dtype, copy=copy)
476-
elif is_string_dtype(dtype):
477-
return self._format_native_types()
478484
elif dtype.kind in "iu":
479485
# we deliberately ignore int32 vs. int64 here.
480486
# See https://github.com/pandas-dev/pandas/issues/24381 for more.

pandas/core/dtypes/common.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import using_string_dtype
16+
1517
from pandas._libs import (
1618
Interval,
1719
Period,
@@ -1470,7 +1472,15 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
14701472
elif isinstance(dtype, np.dtype):
14711473
return False
14721474
else:
1473-
return registry.find(dtype) is not None
1475+
try:
1476+
with warnings.catch_warnings():
1477+
# pandas_dtype(..) can raise UserWarning for class input
1478+
warnings.simplefilter("ignore", UserWarning)
1479+
dtype = pandas_dtype(dtype)
1480+
except (TypeError, ValueError):
1481+
# np.dtype(..) can raise ValueError
1482+
return False
1483+
return isinstance(dtype, ExtensionDtype)
14741484

14751485

14761486
def is_ea_or_datetimelike_dtype(dtype: DtypeObj | None) -> bool:
@@ -1773,6 +1783,12 @@ def pandas_dtype(dtype) -> DtypeObj:
17731783
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
17741784
return dtype
17751785

1786+
# builtin aliases
1787+
if dtype is str and using_string_dtype():
1788+
from pandas.core.arrays.string_ import StringDtype
1789+
1790+
return StringDtype(na_value=np.nan)
1791+
17761792
# registered extension types
17771793
result = registry.find(dtype)
17781794
if result is not None:

pandas/core/indexes/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -6262,7 +6262,11 @@ def _should_compare(self, other: Index) -> bool:
62626262
return False
62636263

62646264
dtype = _unpack_nested_dtype(other)
6265-
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
6265+
return (
6266+
self._is_comparable_dtype(dtype)
6267+
or is_object_dtype(dtype)
6268+
or is_string_dtype(dtype)
6269+
)
62666270

62676271
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
62686272
"""

pandas/core/indexes/interval.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
is_number,
5252
is_object_dtype,
5353
is_scalar,
54+
is_string_dtype,
5455
pandas_dtype,
5556
)
5657
from pandas.core.dtypes.dtypes import (
@@ -712,7 +713,7 @@ def _get_indexer(
712713
# left/right get_indexer, compare elementwise, equality -> match
713714
indexer = self._get_indexer_unique_sides(target)
714715

715-
elif not is_object_dtype(target.dtype):
716+
elif not (is_object_dtype(target.dtype) or is_string_dtype(target.dtype)):
716717
# homogeneous scalar index: use IntervalTree
717718
# we should always have self._should_partial_index(target) here
718719
target = self._maybe_convert_i8(target)

pandas/tests/arrays/floating/test_astype.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,9 @@ def test_astype_str(using_infer_string):
6868

6969
if using_infer_string:
7070
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
71-
tm.assert_extension_array_equal(a.astype("str"), expected)
7271

73-
# TODO(infer_string) this should also be a string array like above
74-
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
75-
tm.assert_numpy_array_equal(a.astype(str), expected)
72+
tm.assert_extension_array_equal(a.astype(str), expected)
73+
tm.assert_extension_array_equal(a.astype("str"), expected)
7674
else:
7775
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
7876

pandas/tests/arrays/integer/test_dtypes.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,9 @@ def test_astype_str(using_infer_string):
281281

282282
if using_infer_string:
283283
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
284-
tm.assert_extension_array_equal(a.astype("str"), expected)
285284

286-
# TODO(infer_string) this should also be a string array like above
287-
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
288-
tm.assert_numpy_array_equal(a.astype(str), expected)
285+
tm.assert_extension_array_equal(a.astype(str), expected)
286+
tm.assert_extension_array_equal(a.astype("str"), expected)
289287
else:
290288
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
291289

pandas/tests/arrays/sparse/test_astype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_astype_all(self, any_real_numpy_dtype):
8181
),
8282
(
8383
SparseArray([0, 1, 10]),
84-
str,
85-
SparseArray(["0", "1", "10"], dtype=SparseDtype(str, "0")),
84+
np.str_,
85+
SparseArray(["0", "1", "10"], dtype=SparseDtype(np.str_, "0")),
8686
),
8787
(SparseArray(["10", "20"]), float, SparseArray([10.0, 20.0])),
8888
(

pandas/tests/arrays/sparse/test_dtype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_construct_from_string_fill_value_raises(string):
184184
[
185185
(SparseDtype(int, 0), float, SparseDtype(float, 0.0)),
186186
(SparseDtype(int, 1), float, SparseDtype(float, 1.0)),
187-
(SparseDtype(int, 1), str, SparseDtype(object, "1")),
187+
(SparseDtype(int, 1), np.str_, SparseDtype(object, "1")),
188188
(SparseDtype(float, 1.5), int, SparseDtype(int, 1)),
189189
],
190190
)

pandas/tests/dtypes/test_common.py

+12
Original file line numberDiff line numberDiff line change
@@ -810,11 +810,23 @@ def test_pandas_dtype_string_dtypes(string_storage):
810810
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
811811
)
812812

813+
with pd.option_context("future.infer_string", True):
814+
# with the default string_storage setting
815+
result = pandas_dtype(str)
816+
assert result == pd.StringDtype(
817+
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
818+
)
819+
813820
with pd.option_context("future.infer_string", True):
814821
with pd.option_context("string_storage", string_storage):
815822
result = pandas_dtype("str")
816823
assert result == pd.StringDtype(string_storage, na_value=np.nan)
817824

825+
with pd.option_context("future.infer_string", True):
826+
with pd.option_context("string_storage", string_storage):
827+
result = pandas_dtype(str)
828+
assert result == pd.StringDtype(string_storage, na_value=np.nan)
829+
818830
with pd.option_context("future.infer_string", False):
819831
with pd.option_context("string_storage", string_storage):
820832
result = pandas_dtype("str")

pandas/tests/extension/base/casting.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def test_tolist(self, data):
4444
assert result == expected
4545

4646
def test_astype_str(self, data):
47-
result = pd.Series(data[:5]).astype(str)
48-
expected = pd.Series([str(x) for x in data[:5]], dtype=str)
47+
result = pd.Series(data[:2]).astype(str)
48+
expected = pd.Series([str(x) for x in data[:2]], dtype=str)
4949
tm.assert_series_equal(result, expected)
5050

5151
@pytest.mark.parametrize(

pandas/tests/extension/json/array.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,8 @@ def astype(self, dtype, copy=True):
208208
return self.copy()
209209
return self
210210
elif isinstance(dtype, StringDtype):
211-
value = self.astype(str) # numpy doesn't like nested dicts
212211
arr_cls = dtype.construct_array_type()
213-
return arr_cls._from_sequence(value, dtype=dtype, copy=False)
212+
return arr_cls._from_sequence(self, dtype=dtype, copy=False)
214213
elif not copy:
215214
return np.asarray([dict(x) for x in self], dtype=dtype)
216215
else:

pandas/tests/extension/test_arrow.py

+5-24
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
pa_version_under13p0,
4444
pa_version_under14p0,
4545
)
46-
import pandas.util._test_decorators as td
4746

4847
from pandas.core.dtypes.dtypes import (
4948
ArrowDtype,
@@ -292,43 +291,25 @@ def test_map(self, data_missing, na_action):
292291
expected = data_missing.to_numpy()
293292
tm.assert_numpy_array_equal(result, expected)
294293

295-
def test_astype_str(self, data, request):
294+
def test_astype_str(self, data, request, using_infer_string):
296295
pa_dtype = data.dtype.pyarrow_dtype
297296
if pa.types.is_binary(pa_dtype):
298297
request.applymarker(
299298
pytest.mark.xfail(
300299
reason=f"For {pa_dtype} .astype(str) decodes.",
301300
)
302301
)
303-
elif (
304-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
305-
) or pa.types.is_duration(pa_dtype):
302+
elif not using_infer_string and (
303+
(pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
304+
or pa.types.is_duration(pa_dtype)
305+
):
306306
request.applymarker(
307307
pytest.mark.xfail(
308308
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
309309
)
310310
)
311311
super().test_astype_str(data)
312312

313-
@pytest.mark.parametrize(
314-
"nullable_string_dtype",
315-
[
316-
"string[python]",
317-
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
318-
],
319-
)
320-
def test_astype_string(self, data, nullable_string_dtype, request):
321-
pa_dtype = data.dtype.pyarrow_dtype
322-
if (
323-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
324-
) or pa.types.is_duration(pa_dtype):
325-
request.applymarker(
326-
pytest.mark.xfail(
327-
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
328-
)
329-
)
330-
super().test_astype_string(data, nullable_string_dtype)
331-
332313
def test_from_dtype(self, data, request):
333314
pa_dtype = data.dtype.pyarrow_dtype
334315
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):

pandas/tests/frame/methods/test_astype.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -168,21 +168,21 @@ def test_astype_str(self):
168168
"d": list(map(str, d._values)),
169169
"e": list(map(str, e._values)),
170170
},
171-
dtype="object",
171+
dtype="str",
172172
)
173173

174174
tm.assert_frame_equal(result, expected)
175175

176-
def test_astype_str_float(self):
176+
def test_astype_str_float(self, using_infer_string):
177177
# see GH#11302
178178
result = DataFrame([np.nan]).astype(str)
179-
expected = DataFrame(["nan"], dtype="object")
179+
expected = DataFrame([np.nan if using_infer_string else "nan"], dtype="str")
180180

181181
tm.assert_frame_equal(result, expected)
182182
result = DataFrame([1.12345678901234567890]).astype(str)
183183

184184
val = "1.1234567890123457"
185-
expected = DataFrame([val], dtype="object")
185+
expected = DataFrame([val], dtype="str")
186186
tm.assert_frame_equal(result, expected)
187187

188188
@pytest.mark.parametrize("dtype_class", [dict, Series])
@@ -284,7 +284,7 @@ def test_astype_duplicate_col_series_arg(self):
284284
result = df.astype(dtypes)
285285
expected = DataFrame(
286286
{
287-
0: Series(vals[:, 0].astype(str), dtype=object),
287+
0: Series(vals[:, 0].astype(str), dtype="str"),
288288
1: vals[:, 1],
289289
2: pd.array(vals[:, 2], dtype="Float64"),
290290
3: vals[:, 3],
@@ -647,25 +647,26 @@ def test_astype_dt64tz(self, timezone_frame):
647647
# dt64tz->dt64 deprecated
648648
timezone_frame.astype("datetime64[ns]")
649649

650-
def test_astype_dt64tz_to_str(self, timezone_frame):
650+
def test_astype_dt64tz_to_str(self, timezone_frame, using_infer_string):
651651
# str formatting
652652
result = timezone_frame.astype(str)
653+
na_value = np.nan if using_infer_string else "NaT"
653654
expected = DataFrame(
654655
[
655656
[
656657
"2013-01-01",
657658
"2013-01-01 00:00:00-05:00",
658659
"2013-01-01 00:00:00+01:00",
659660
],
660-
["2013-01-02", "NaT", "NaT"],
661+
["2013-01-02", na_value, na_value],
661662
[
662663
"2013-01-03",
663664
"2013-01-03 00:00:00-05:00",
664665
"2013-01-03 00:00:00+01:00",
665666
],
666667
],
667668
columns=timezone_frame.columns,
668-
dtype="object",
669+
dtype="str",
669670
)
670671
tm.assert_frame_equal(result, expected)
671672

pandas/tests/frame/methods/test_select_dtypes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def test_select_dtypes_include_using_list_like(self, using_infer_string):
9999
ei = df[["a"]]
100100
tm.assert_frame_equal(ri, ei)
101101

102+
ri = df.select_dtypes(include=[str])
103+
tm.assert_frame_equal(ri, ei)
104+
102105
def test_select_dtypes_exclude_using_list_like(self):
103106
df = DataFrame(
104107
{
@@ -358,7 +361,7 @@ def test_select_dtypes_datetime_with_tz(self):
358361
@pytest.mark.parametrize("dtype", [str, "str", np.bytes_, "S1", np.str_, "U1"])
359362
@pytest.mark.parametrize("arg", ["include", "exclude"])
360363
def test_select_dtypes_str_raises(self, dtype, arg, using_infer_string):
361-
if using_infer_string and dtype == "str":
364+
if using_infer_string and (dtype == "str" or dtype is str):
362365
# this is tested below
363366
pytest.skip("Selecting string columns works with future strings")
364367
df = DataFrame(

0 commit comments

Comments
 (0)