Skip to content

Commit 7415aca

Browse files
String dtype: disallow specifying a storage in [..] with the 'str' dtype string alias (#60661)
1 parent 57d2489 commit 7415aca

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

pandas/core/dtypes/dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2339,7 +2339,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
23392339
)
23402340
if not string.endswith("[pyarrow]"):
23412341
raise TypeError(f"'{string}' must end with '[pyarrow]'")
2342-
if string == "string[pyarrow]":
2342+
if string in ("string[pyarrow]", "str[pyarrow]"):
23432343
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
23442344
raise TypeError("string[pyarrow] should be constructed by StringDtype")
23452345
if pa_version_under10p1:

pandas/tests/dtypes/test_common.py

+20
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,26 @@ def test_pandas_dtype_string_dtypes(string_storage):
837837
assert result == pd.StringDtype(string_storage, na_value=pd.NA)
838838

839839

840+
def test_pandas_dtype_string_dtype_alias_with_storage():
841+
with pytest.raises(TypeError, match="not understood"):
842+
pandas_dtype("str[python]")
843+
844+
with pytest.raises(TypeError, match="not understood"):
845+
pandas_dtype("str[pyarrow]")
846+
847+
result = pandas_dtype("string[python]")
848+
assert result == pd.StringDtype("python", na_value=pd.NA)
849+
850+
if HAS_PYARROW:
851+
result = pandas_dtype("string[pyarrow]")
852+
assert result == pd.StringDtype("pyarrow", na_value=pd.NA)
853+
else:
854+
with pytest.raises(
855+
ImportError, match="required for PyArrow backed StringArray"
856+
):
857+
pandas_dtype("string[pyarrow]")
858+
859+
840860
@td.skip_if_installed("pyarrow")
841861
def test_construct_from_string_without_pyarrow_installed():
842862
# GH 57928

pandas/tests/strings/test_get_dummies.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas.util._test_decorators as td
77

88
from pandas import (
9+
ArrowDtype,
910
DataFrame,
1011
Index,
1112
MultiIndex,
@@ -113,15 +114,17 @@ def test_get_dummies_with_str_dtype(any_string_dtype):
113114
# GH#47872
114115
@td.skip_if_no("pyarrow")
115116
def test_get_dummies_with_pa_str_dtype(any_string_dtype):
117+
import pyarrow as pa
118+
116119
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
117-
result = s.str.get_dummies("|", dtype="str[pyarrow]")
120+
result = s.str.get_dummies("|", dtype=ArrowDtype(pa.string()))
118121
expected = DataFrame(
119122
[
120123
["true", "true", "false"],
121124
["true", "false", "true"],
122125
["false", "false", "false"],
123126
],
124127
columns=list("abc"),
125-
dtype="str[pyarrow]",
128+
dtype=ArrowDtype(pa.string()),
126129
)
127130
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)