Skip to content

Commit 61823bd

Browse files
rohanjain101Rohan Jain
authored andcommitted
Series.str.find fix for pd.ArrowDtype(pa.string()) (pandas-dev#56792)
* fix find * gh reference * add test for Nones * fix min version compat * restore test * improve test cases * fix empty string * inline * improve tests * fix * Revert "fix" This reverts commit 7fa21eb. * fix * merge * inline --------- Co-authored-by: Rohan Jain <[email protected]>
1 parent 024a503 commit 61823bd

File tree

2 files changed

+90
-18
lines changed

2 files changed

+90
-18
lines changed

pandas/core/arrays/arrow/array.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -2364,20 +2364,26 @@ def _str_fullmatch(
23642364
return self._str_match(pat, case, flags, na)
23652365

23662366
def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
2367-
if start != 0 and end is not None:
2367+
if (start == 0 or start is None) and end is None:
2368+
result = pc.find_substring(self._pa_array, sub)
2369+
else:
2370+
if sub == "":
2371+
# GH 56792
2372+
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
2373+
return type(self)(pa.chunked_array(result))
2374+
if start is None:
2375+
start_offset = 0
2376+
start = 0
2377+
elif start < 0:
2378+
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
2379+
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
2380+
else:
2381+
start_offset = start
23682382
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
23692383
result = pc.find_substring(slices, sub)
2370-
not_found = pc.equal(result, -1)
2371-
start_offset = max(0, start)
2384+
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
23722385
offset_result = pc.add(result, start_offset)
2373-
result = pc.if_else(not_found, result, offset_result)
2374-
elif start == 0 and end is None:
2375-
slices = self._pa_array
2376-
result = pc.find_substring(slices, sub)
2377-
else:
2378-
raise NotImplementedError(
2379-
f"find not implemented with {sub=}, {start=}, {end=}"
2380-
)
2386+
result = pc.if_else(found, offset_result, -1)
23812387
return type(self)(result)
23822388

23832389
def _str_join(self, sep: str) -> Self:

pandas/tests/extension/test_arrow.py

+73-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
BytesIO,
2424
StringIO,
2525
)
26+
from itertools import combinations
2627
import operator
2728
import pickle
2829
import re
@@ -1933,13 +1934,18 @@ def test_str_fullmatch(pat, case, na, exp):
19331934

19341935

19351936
@pytest.mark.parametrize(
1936-
"sub, start, end, exp, exp_typ",
1937-
[["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [1, None], pa.int64()]],
1937+
"sub, start, end, exp, exp_type",
1938+
[
1939+
["ab", 0, None, [0, None], pa.int32()],
1940+
["bc", 1, 3, [1, None], pa.int64()],
1941+
["ab", 1, 3, [-1, None], pa.int64()],
1942+
["ab", -3, -3, [-1, None], pa.int64()],
1943+
],
19381944
)
1939-
def test_str_find(sub, start, end, exp, exp_typ):
1945+
def test_str_find(sub, start, end, exp, exp_type):
19401946
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
19411947
result = ser.str.find(sub, start=start, end=end)
1942-
expected = pd.Series(exp, dtype=ArrowDtype(exp_typ))
1948+
expected = pd.Series(exp, dtype=ArrowDtype(exp_type))
19431949
tm.assert_series_equal(result, expected)
19441950

19451951

@@ -1951,10 +1957,70 @@ def test_str_find_negative_start():
19511957
tm.assert_series_equal(result, expected)
19521958

19531959

1954-
def test_str_find_notimplemented():
1960+
def test_str_find_no_end():
19551961
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
1956-
with pytest.raises(NotImplementedError, match="find not implemented"):
1957-
ser.str.find("ab", start=1)
1962+
if pa_version_under13p0:
1963+
# https://github.com/apache/arrow/issues/36311
1964+
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
1965+
ser.str.find("ab", start=1)
1966+
else:
1967+
result = ser.str.find("ab", start=1)
1968+
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
1969+
tm.assert_series_equal(result, expected)
1970+
1971+
1972+
def test_str_find_negative_start_negative_end():
1973+
# GH 56791
1974+
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
1975+
result = ser.str.find(sub="d", start=-6, end=-3)
1976+
expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64()))
1977+
tm.assert_series_equal(result, expected)
1978+
1979+
1980+
def test_str_find_large_start():
1981+
# GH 56791
1982+
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
1983+
if pa_version_under13p0:
1984+
# https://github.com/apache/arrow/issues/36311
1985+
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
1986+
ser.str.find(sub="d", start=16)
1987+
else:
1988+
result = ser.str.find(sub="d", start=16)
1989+
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
1990+
tm.assert_series_equal(result, expected)
1991+
1992+
1993+
@pytest.mark.skipif(
1994+
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
1995+
)
1996+
@pytest.mark.parametrize("start", list(range(-15, 15)) + [None])
1997+
@pytest.mark.parametrize("end", list(range(-15, 15)) + [None])
1998+
@pytest.mark.parametrize(
1999+
"sub",
2000+
["abcaadef"[x:y] for x, y in combinations(range(len("abcaadef") + 1), r=2)]
2001+
+ [
2002+
"",
2003+
"az",
2004+
"abce",
2005+
],
2006+
)
2007+
def test_str_find_e2e(start, end, sub):
2008+
s = pd.Series(
2009+
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
2010+
dtype=ArrowDtype(pa.string()),
2011+
)
2012+
object_series = s.astype(pd.StringDtype())
2013+
result = s.str.find(sub, start, end)
2014+
expected = object_series.str.find(sub, start, end).astype(result.dtype)
2015+
tm.assert_series_equal(result, expected)
2016+
2017+
2018+
def test_str_find_negative_start_negative_end_no_match():
2019+
# GH 56791
2020+
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
2021+
result = ser.str.find(sub="d", start=-3, end=-6)
2022+
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
2023+
tm.assert_series_equal(result, expected)
19582024

19592025

19602026
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)