Skip to content

Commit 3f8d3e4

Browse files
authored
BUG (string): ArrowStringArray.find corner cases (#59562)
1 parent 08431f1 commit 3f8d3e4

File tree

4 files changed

+61
-55
lines changed

4 files changed

+61
-55
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from functools import partial
44
from typing import (
55
TYPE_CHECKING,
6+
Any,
67
Literal,
78
)
89

910
import numpy as np
1011

1112
from pandas.compat import (
1213
pa_version_under10p1,
14+
pa_version_under13p0,
1315
pa_version_under17p0,
1416
)
1517

@@ -20,7 +22,10 @@
2022
import pyarrow.compute as pc
2123

2224
if TYPE_CHECKING:
23-
from collections.abc import Sized
25+
from collections.abc import (
26+
Callable,
27+
Sized,
28+
)
2429

2530
from pandas._typing import (
2631
Scalar,
@@ -42,6 +47,9 @@ def _convert_int_result(self, result):
4247
# Convert an integer-dtype result to the appropriate result type
4348
raise NotImplementedError
4449

50+
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
    """
    Hook for running a Python callable over the array's values.

    The mixin itself provides no implementation: this base version always
    raises, so a concrete array class has to supply its own.

    Parameters
    ----------
    func : Callable
        Function to apply to the values.  ``_str_find`` passes a
        one-argument lambda that calls ``val.find(...)`` on the value it
        receives.

    Returns
    -------
    list of list
        Callers hand the result straight to ``pa.chunked_array``.
        Presumably there is one inner list per chunk -- TODO confirm against
        the concrete implementation.

    Raises
    ------
    NotImplementedError
        Always, in this base version.
    """
    raise NotImplementedError
52+
4553
def _str_pad(
4654
self,
4755
width: int,
@@ -205,3 +213,37 @@ def _str_contains(
205213
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
206214
result = result.fill_null(na)
207215
return self._convert_bool_result(result)
216+
217+
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the position of the first occurrence of ``sub`` in each element.

    Three strategies are used, checked in this order:

    1. On pyarrow older than 13, every call except "``start != 0`` with an
       explicit ``end``" and "``start == 0`` with no ``end``" is evaluated
       element by element with Python's ``str.find`` (GH#59562).
    2. With no effective bounds (``start`` is 0 or None and ``end`` is
       None) the whole array goes through ``pc.find_substring``.
    3. Otherwise each string is sliced to ``[start:end]`` and searched,
       and hits are shifted back to positions in the unsliced string.
       An empty ``sub`` is evaluated element by element with ``str.find``
       instead (GH#56792).

    Parameters
    ----------
    sub : str
        Substring to look for.
    start : int, default 0
        Left bound of the search window.  ``None`` is also handled and is
        treated like 0 on the pyarrow code paths.
    end : int or None, default None
        Right bound of the search window; ``None`` leaves it open.

    Returns
    -------
    Whatever ``self._convert_int_result`` produces for the computed
    positions.  A position of -1 marks "not found".
    """

    def find_elementwise():
        # Let Python's own ``str.find`` compute every position.
        per_chunk = self._apply_elementwise(lambda val: val.find(sub, start, end))
        return self._convert_int_result(pa.chunked_array(per_chunk))

    if pa_version_under13p0 and not (
        (start != 0 and end is not None) or (start == 0 and end is None)
    ):
        # GH#59562
        # NOTE(review): presumably sidesteps apache/arrow#36311 -- confirm.
        return find_elementwise()

    if (start == 0 or start is None) and end is None:
        # Nothing to slice: search the stored strings directly.
        positions = pc.find_substring(self._pa_array, sub)
        return self._convert_int_result(positions)

    if sub == "":
        # GH#56792
        return find_elementwise()

    # ``slice_start`` is what the slice kernel receives; ``offset`` is what
    # has to be added to a hit inside the slice to get its position in the
    # full string.
    if start is None:
        slice_start = 0
        offset = 0
    elif start < 0:
        slice_start = start
        # A negative start counts from the end of each string, so the
        # offset differs per element and is clamped at 0.
        from_end = pc.add(start, pc.utf8_length(self._pa_array))
        offset = pc.if_else(pc.less(from_end, 0), 0, from_end)
    else:
        slice_start = offset = start

    window = pc.utf8_slice_codeunits(self._pa_array, slice_start, stop=end)
    # NOTE(review): pyarrow appears to document ``find_substring`` results
    # as byte offsets, whereas ``offset`` comes from ``utf8_length`` --
    # confirm the two agree for non-ASCII data.
    positions = pc.find_substring(window, sub)
    was_found = pc.not_equal(positions, pa.scalar(-1, type=positions.type))
    shifted = pc.add(positions, offset)
    # Only shift real hits; misses must stay at -1.
    return self._convert_int_result(pc.if_else(was_found, shifted, -1))

pandas/core/arrays/arrow/array.py

-23
Original file line numberDiff line numberDiff line change
@@ -2373,29 +2373,6 @@ def _str_fullmatch(
23732373
pat = f"{pat}$"
23742374
return self._str_match(pat, case, flags, na)
23752375

2376-
def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
2377-
if (start == 0 or start is None) and end is None:
2378-
result = pc.find_substring(self._pa_array, sub)
2379-
else:
2380-
if sub == "":
2381-
# GH 56792
2382-
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2383-
return type(self)(pa.chunked_array(result))
2384-
if start is None:
2385-
start_offset = 0
2386-
start = 0
2387-
elif start < 0:
2388-
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
2389-
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
2390-
else:
2391-
start_offset = start
2392-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
2393-
result = pc.find_substring(slices, sub)
2394-
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
2395-
offset_result = pc.add(result, start_offset)
2396-
result = pc.if_else(found, offset_result, -1)
2397-
return type(self)(result)
2398-
23992376
def _str_join(self, sep: str) -> Self:
24002377
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
24012378
self._pa_array.type

pandas/core/arrays/string_arrow.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0):
416416
return self._convert_int_result(result)
417417

418418
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
419-
if start != 0 and end is not None:
420-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
421-
result = pc.find_substring(slices, sub)
422-
not_found = pc.equal(result, -1)
423-
offset_result = pc.add(result, end - start)
424-
result = pc.if_else(not_found, result, offset_result)
425-
elif start == 0 and end is None:
426-
slices = self._pa_array
427-
result = pc.find_substring(slices, sub)
428-
else:
419+
if (
420+
pa_version_under13p0
421+
and not (start != 0 and end is not None)
422+
and not (start == 0 and end is None)
423+
):
424+
# GH#59562
429425
return super()._str_find(sub, start, end)
430-
return self._convert_int_result(result)
426+
return ArrowStringArrayMixin._str_find(self, sub, start, end)
431427

432428
def _str_get_dummies(self, sep: str = "|"):
433429
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)

pandas/tests/extension/test_arrow.py

+11-20
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import numpy as np
3333
import pytest
3434

35-
from pandas._config import using_string_dtype
36-
3735
from pandas._libs import lib
3836
from pandas._libs.tslibs import timezones
3937
from pandas.compat import (
@@ -1947,14 +1945,9 @@ def test_str_find_negative_start():
19471945

19481946
def test_str_find_no_end():
19491947
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
1950-
if pa_version_under13p0:
1951-
# https://github.com/apache/arrow/issues/36311
1952-
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
1953-
ser.str.find("ab", start=1)
1954-
else:
1955-
result = ser.str.find("ab", start=1)
1956-
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
1957-
tm.assert_series_equal(result, expected)
1948+
result = ser.str.find("ab", start=1)
1949+
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
1950+
tm.assert_series_equal(result, expected)
19581951

19591952

19601953
def test_str_find_negative_start_negative_end():
@@ -1968,17 +1961,11 @@ def test_str_find_negative_start_negative_end():
19681961
def test_str_find_large_start():
19691962
# GH 56791
19701963
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
1971-
if pa_version_under13p0:
1972-
# https://github.com/apache/arrow/issues/36311
1973-
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
1974-
ser.str.find(sub="d", start=16)
1975-
else:
1976-
result = ser.str.find(sub="d", start=16)
1977-
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
1978-
tm.assert_series_equal(result, expected)
1964+
result = ser.str.find(sub="d", start=16)
1965+
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
1966+
tm.assert_series_equal(result, expected)
19791967

19801968

1981-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
19821969
@pytest.mark.skipif(
19831970
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
19841971
)
@@ -1990,11 +1977,15 @@ def test_str_find_e2e(start, end, sub):
19901977
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
19911978
dtype=ArrowDtype(pa.string()),
19921979
)
1993-
object_series = s.astype(pd.StringDtype())
1980+
object_series = s.astype(pd.StringDtype(storage="python"))
19941981
result = s.str.find(sub, start, end)
19951982
expected = object_series.str.find(sub, start, end).astype(result.dtype)
19961983
tm.assert_series_equal(result, expected)
19971984

1985+
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
1986+
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
1987+
tm.assert_series_equal(result2, expected)
1988+
19981989

19991990
def test_str_find_negative_start_negative_end_no_match():
20001991
# GH 56791

0 commit comments

Comments
 (0)