Skip to content

Commit 44325c1

Browse files
jbrockmendel, jorisvandenbossche
jbrockmendel authored and jorisvandenbossche committed
BUG (string): ArrowStringArray.find corner cases (pandas-dev#59562)
1 parent 553780a commit 44325c1

File tree

4 files changed

+99
-32
lines changed

4 files changed

+99
-32
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from functools import partial
44
from typing import (
55
TYPE_CHECKING,
6+
Any,
67
Literal,
78
)
89

910
import numpy as np
1011

1112
from pandas.compat import (
1213
pa_version_under10p1,
14+
pa_version_under13p0,
1315
pa_version_under17p0,
1416
)
1517

@@ -20,7 +22,10 @@
2022
import pyarrow.compute as pc
2123

2224
if TYPE_CHECKING:
23-
from collections.abc import Sized
25+
from collections.abc import (
26+
Callable,
27+
Sized,
28+
)
2429

2530
from pandas._typing import Scalar
2631

@@ -39,6 +44,9 @@ def _convert_int_result(self, result):
3944
# Convert an integer-dtype result to the appropriate result type
4045
raise NotImplementedError
4146

47+
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
48+
raise NotImplementedError
49+
4250
def _str_pad(
4351
self,
4452
width: int,
@@ -201,3 +209,37 @@ def _str_contains(
201209
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
202210
result = result.fill_null(na)
203211
return self._convert_bool_result(result)
212+
213+
def _str_find(self, sub: str, start: int | None = 0, end: int | None = None):
    """
    Find the lowest index of ``sub`` within ``[start:end]`` of each string.

    Follows ``str.find``: the reported index is relative to the beginning
    of the whole string (not of the searched window) and -1 means that
    ``sub`` was not found.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int or None, default 0
        Left edge of the search window; ``None`` is treated like 0 and a
        negative value is taken relative to the end of each string.
    end : int or None, default None
        Right edge of the search window; ``None`` means "to the end".

    Returns
    -------
    The integer result wrapped by ``self._convert_int_result``.
    """

    def _find_in_python():
        # str.find itself defines the expected result, so this path is the
        # reference behaviour; it merely bypasses the pyarrow kernels.
        res_list = self._apply_elementwise(lambda val: val.find(sub, start, end))
        # Spell the type out so the result does not depend on what pyarrow
        # manages to infer from the values (inference has nothing to work
        # with when ``res_list`` is empty).
        return self._convert_int_result(
            pa.chunked_array(res_list, type=pa.int64())
        )

    # True when the whole string is searched, i.e. no slicing is required.
    unbounded = (start == 0 or start is None) and end is None

    if (
        pa_version_under13p0
        and not (start != 0 and end is not None)
        and not (start == 0 and end is None)
    ):
        # GH#59562
        return _find_in_python()

    if sub == "" and not unbounded:
        # GH#56792
        return _find_in_python()

    # pc.find_substring reports *byte* offsets, whereas str.find and
    # pc.utf8_slice_codeunits count code points.  The two only coincide for
    # ASCII data, so anything else is searched in Python.  ``is False`` is
    # deliberate: a null answer from pc.all (nothing to inspect) keeps the
    # vectorised path.
    if pc.all(pc.string_is_ascii(self._pa_array)).as_py() is False:
        return _find_in_python()

    if unbounded:
        return self._convert_int_result(pc.find_substring(self._pa_array, sub))

    slice_start = 0 if start is None else start
    if slice_start < 0:
        # Translate the negative start into a per-element absolute
        # position, clamped at 0 for strings shorter than ``-start``.
        start_offset = pc.add(slice_start, pc.utf8_length(self._pa_array))
        start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
    else:
        start_offset = slice_start

    slices = pc.utf8_slice_codeunits(self._pa_array, slice_start, stop=end)
    result = pc.find_substring(slices, sub)
    # Hits are relative to the slice: shift them back to positions in the
    # full string, but leave the -1 "not found" marker untouched.
    found = pc.not_equal(result, pa.scalar(-1, type=result.type))
    offset_result = pc.add(result, start_offset)
    result = pc.if_else(found, offset_result, -1)
    return self._convert_int_result(result)

pandas/core/arrays/arrow/array.py

-17
Original file line numberDiff line numberDiff line change
@@ -2348,23 +2348,6 @@ def _str_fullmatch(
23482348
pat = f"{pat}$"
23492349
return self._str_match(pat, case, flags, na)
23502350

2351-
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
2352-
if start != 0 and end is not None:
2353-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
2354-
result = pc.find_substring(slices, sub)
2355-
not_found = pc.equal(result, -1)
2356-
start_offset = max(0, start)
2357-
offset_result = pc.add(result, start_offset)
2358-
result = pc.if_else(not_found, result, offset_result)
2359-
elif start == 0 and end is None:
2360-
slices = self._pa_array
2361-
result = pc.find_substring(slices, sub)
2362-
else:
2363-
raise NotImplementedError(
2364-
f"find not implemented with {sub=}, {start=}, {end=}"
2365-
)
2366-
return type(self)(result)
2367-
23682351
def _str_join(self, sep: str):
23692352
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
23702353
self._pa_array.type

pandas/core/arrays/string_arrow.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0):
416416
return self._convert_int_result(result)
417417

418418
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
419-
if start != 0 and end is not None:
420-
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
421-
result = pc.find_substring(slices, sub)
422-
not_found = pc.equal(result, -1)
423-
offset_result = pc.add(result, end - start)
424-
result = pc.if_else(not_found, result, offset_result)
425-
elif start == 0 and end is None:
426-
slices = self._pa_array
427-
result = pc.find_substring(slices, sub)
428-
else:
419+
if (
420+
pa_version_under13p0
421+
and not (start != 0 and end is not None)
422+
and not (start == 0 and end is None)
423+
):
424+
# GH#59562
429425
return super()._str_find(sub, start, end)
430-
return self._convert_int_result(result)
426+
return ArrowStringArrayMixin._str_find(self, sub, start, end)
431427

432428
def _str_get_dummies(self, sep: str = "|"):
433429
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)

pandas/tests/extension/test_arrow.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -1925,10 +1925,56 @@ def test_str_find_negative_start():
19251925
tm.assert_series_equal(result, expected)
19261926

19271927

1928-
def test_str_find_notimplemented():
1928+
def test_str_find_no_end():
19291929
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
1930-
with pytest.raises(NotImplementedError, match="find not implemented"):
1931-
ser.str.find("ab", start=1)
1930+
result = ser.str.find("ab", start=1)
1931+
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
1932+
tm.assert_series_equal(result, expected)
1933+
1934+
1935+
def test_str_find_negative_start_negative_end():
    # GH 56791
    # "abcdefg"[-6:-3] is "bcd", so "d" is reported at index 3 of the
    # full string.
    string_dtype = ArrowDtype(pa.string())
    data = pd.Series(["abcdefg", None], dtype=string_dtype)
    want = pd.Series([3, None], dtype=ArrowDtype(pa.int64()))
    got = data.str.find(sub="d", start=-6, end=-3)
    tm.assert_series_equal(got, want)
1941+
1942+
1943+
def test_str_find_large_start():
    # GH 56791
    # A start beyond the end of the 7-character string cannot match.
    string_dtype = ArrowDtype(pa.string())
    data = pd.Series(["abcdefg", None], dtype=string_dtype)
    want = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
    got = data.str.find(sub="d", start=16)
    tm.assert_series_equal(got, want)
1949+
1950+
1951+
@pytest.mark.skipif(
    pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
@pytest.mark.parametrize("start", [-15, -3, 0, 1, 15, None])
@pytest.mark.parametrize("end", [-15, -1, 0, 3, 15, None])
@pytest.mark.parametrize("sub", ["", "az", "abce", "a", "caa"])
def test_str_find_e2e(start, end, sub):
    # The python-storage string dtype acts as the reference; both
    # pyarrow-backed dtypes must agree with it for every combination.
    values = ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""]
    arrow_ser = pd.Series(values, dtype=ArrowDtype(pa.string()))

    # ArrowDtype(pa.string()) vs. the python reference
    arrow_result = arrow_ser.str.find(sub, start, end)
    reference = (
        arrow_ser.astype(pd.StringDtype(storage="python"))
        .str.find(sub, start, end)
        .astype(arrow_result.dtype)
    )
    tm.assert_series_equal(arrow_result, reference)

    # StringDtype(storage="pyarrow") vs. the python reference
    string_arrow_result = (
        arrow_ser.astype(pd.StringDtype(storage="pyarrow"))
        .str.find(sub, start, end)
        .astype(arrow_result.dtype)
    )
    tm.assert_series_equal(string_arrow_result, reference)
1970+
1971+
1972+
def test_str_find_negative_start_negative_end_no_match():
    # GH 56791
    # "abcdefg"[-3:-6] is empty (start lies to the right of end), so the
    # search cannot succeed.
    string_dtype = ArrowDtype(pa.string())
    data = pd.Series(["abcdefg", None], dtype=string_dtype)
    want = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
    got = data.str.find(sub="d", start=-3, end=-6)
    tm.assert_series_equal(got, want)
19321978

19331979

19341980
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)