Skip to content

Commit 03272ab

Browse files
Backport PR pandas-dev#51548 on branch 2.0.x (BUG: parsing pyarrow dtypes corner case) (pandas-dev#51606)
Backport PR pandas-dev#51548: BUG: parsing pyarrow dtypes corner case Co-authored-by: jbrockmendel <[email protected]>
1 parent 02e0cf6 commit 03272ab

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

doc/source/whatsnew/v2.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1420,7 +1420,7 @@ Metadata
14201420

14211421
Other
14221422
^^^^^
1423-
1423+
- Bug where dtype strings containing "[pyarrow]" more than once were incorrectly accepted (:issue:`51548`)
14241424
- Bug in :meth:`Series.searchsorted` inconsistent behavior when accepting :class:`DataFrame` as parameter ``value`` (:issue:`49620`)
14251425
- Bug in :func:`array` failing to raise on :class:`DataFrame` inputs (:issue:`51167`)
14261426

pandas/core/arrays/arrow/dtype.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
197197
if string == "string[pyarrow]":
198198
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
199199
raise TypeError("string[pyarrow] should be constructed by StringDtype")
200-
base_type = string.split("[pyarrow]")[0]
200+
201+
base_type = string[:-9] # drop the trailing 9 characters, i.e. the "[pyarrow]" suffix
201202
try:
202203
pa_dtype = pa.type_for_alias(base_type)
203204
except ValueError as err:

pandas/tests/extension/test_arrow.py

+11
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,17 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
12671267
ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
12681268

12691269

1270+
def test_arrowdtype_construct_from_string_type_only_one_pyarrow():
1271+
# GH#51225
1272+
invalid = "int64[pyarrow]foobar[pyarrow]"
1273+
msg = (
1274+
r"Passing pyarrow type specific parameters \(\[pyarrow\]\) in the "
1275+
r"string is not supported\."
1276+
)
1277+
with pytest.raises(NotImplementedError, match=msg):
1278+
pd.Series(range(3), dtype=invalid)
1279+
1280+
12701281
@pytest.mark.parametrize(
12711282
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
12721283
)

0 commit comments

Comments
 (0)