Skip to content

Commit e84b8e9

Browse files
String dtype: disallow specifying the 'str' dtype with storage in [..] in string alias (pandas-dev#60661)
(cherry picked from commit 7415aca)
1 parent e90bb0e commit e84b8e9

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

pandas/core/dtypes/dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2242,7 +2242,7 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
22422242
)
22432243
if not string.endswith("[pyarrow]"):
22442244
raise TypeError(f"'{string}' must end with '[pyarrow]'")
2245-
if string == "string[pyarrow]":
2245+
if string in ("string[pyarrow]", "str[pyarrow]"):
22462246
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
22472247
raise TypeError("string[pyarrow] should be constructed by StringDtype")
22482248

pandas/tests/dtypes/test_common.py

+27
Original file line numberDiff line numberDiff line change
@@ -835,3 +835,30 @@ def test_pandas_dtype_string_dtypes(string_storage):
835835
with pd.option_context("string_storage", string_storage):
836836
result = pandas_dtype("string")
837837
assert result == pd.StringDtype(string_storage, na_value=pd.NA)
838+
839+
840+
def test_pandas_dtype_string_dtype_alias_with_storage():
841+
with pytest.raises(TypeError, match="not understood"):
842+
pandas_dtype("str[python]")
843+
844+
with pytest.raises(TypeError, match="not understood"):
845+
pandas_dtype("str[pyarrow]")
846+
847+
result = pandas_dtype("string[python]")
848+
assert result == pd.StringDtype("python", na_value=pd.NA)
849+
850+
if HAS_PYARROW:
851+
result = pandas_dtype("string[pyarrow]")
852+
assert result == pd.StringDtype("pyarrow", na_value=pd.NA)
853+
else:
854+
with pytest.raises(
855+
ImportError, match="required for PyArrow backed StringArray"
856+
):
857+
pandas_dtype("string[pyarrow]")
858+
859+
860+
@td.skip_if_installed("pyarrow")
861+
def test_construct_from_string_without_pyarrow_installed():
862+
# GH 57928
863+
with pytest.raises(ImportError, match="pyarrow>=10.0.1 is required"):
864+
pd.Series([-1.5, 0.2, None], dtype="float32[pyarrow]")

0 commit comments

Comments
 (0)