Skip to content

Commit 0de96cb

Browse files
author
Rohan Jain
committed
fix empty string
1 parent 771e2f6 commit 0de96cb

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

pandas/core/arrays/arrow/array.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2331,20 +2331,32 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
23312331
if (start == 0 or start is None) and end is None:
23322332
result = pc.find_substring(self._pa_array, sub)
23332333
else:
2334+
length = pc.utf8_length(self._pa_array)
23342335
if start is None:
23352336
start_offset = 0
23362337
start = 0
23372338
elif start < 0:
2338-
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
2339+
start_offset = pc.add(start, length)
23392340
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
23402341
else:
23412342
start_offset = start
23422343
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
23432344
result = pc.find_substring(slices, sub)
2344-
not_found = pc.equal(result, pa.scalar(-1, type=result.type))
2345-
2345+
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
2346+
if end is None:
2347+
end = length
2348+
elif end < 0:
2349+
end = pc.add(end, length)
2350+
end = pc.if_else(pc.less(end, 0), 0, end)
2351+
found = pc.and_(
2352+
found,
2353+
pc.and_(
2354+
pc.less_equal(start_offset, end),
2355+
pc.less_equal(start_offset, length),
2356+
),
2357+
)
23462358
offset_result = pc.add(result, start_offset)
2347-
result = pc.if_else(not_found, result, offset_result)
2359+
result = pc.if_else(found, offset_result, -1)
23482360
return type(self)(result)
23492361

23502362
def _str_join(self, sep: str) -> Self:

pandas/tests/extension/test_arrow.py

+22
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
BytesIO,
2424
StringIO,
2525
)
26+
from itertools import combinations
2627
import operator
2728
import pickle
2829
import re
@@ -1975,6 +1976,27 @@ def test_str_find_large_start():
19751976
tm.assert_series_equal(result, expected)
19761977

19771978

1979+
def _get_all_substrings(string):
1980+
length = len(string) + 1
1981+
return [string[x:y] for x, y in combinations(range(length), r=2)]
1982+
1983+
1984+
@pytest.mark.xfail(
    pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
def test_str_find_e2e():
    """Compare Arrow-backed ``Series.str.find`` with Python's ``str.find``.

    Every combination of ``start`` and ``end`` drawn from ``-15 .. 14`` plus
    ``None`` is tried against every non-empty substring of the sample
    string, the empty string, and two substrings that do not occur
    (``"az"`` and ``"abce"``). The offsets reach well past both ends of the
    8-character sample, so out-of-range bounds are covered too.

    The test is marked ``xfail`` when ``pa_version_under13p0`` is true; the
    linked Arrow issue is given as the reason.
    """
    string = "abcdefgh"
    s = pd.Series([string], dtype=ArrowDtype(pa.string()))
    substrings = _get_all_substrings(string) + ["", "az", "abce"]
    offsets = list(range(-15, 15)) + [None]
    for start in offsets:
        for end in offsets:
            for sub in substrings:
                result = s.str.find(sub, start, end)
                # Python's own str.find is the reference; reuse the result
                # dtype so only the values are being compared.
                expected = pd.Series([string.find(sub, start, end)], dtype=result.dtype)
                tm.assert_series_equal(result, expected)
1998+
1999+
19782000
def test_str_find_negative_start_negative_end_no_match():
19792001
# GH 56791
19802002
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))

0 commit comments

Comments
 (0)