Series.str.find fix for pd.ArrowDtype(pa.string()) #56792

Merged
merged 16 commits into from
Feb 2, 2024
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
@@ -803,6 +803,7 @@ Strings
 - Bug in :meth:`DataFrame.reindex` not matching :class:`Index` with ``string[pyarrow_numpy]`` dtype (:issue:`56106`)
 - Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`)
 - Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`)
+- Bug in :meth:`Series.str.find` when ``start < 0`` and ``end < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56791`)
 - Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`)
 - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
 - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for :class:`ArrowDtype` with ``pyarrow.string`` dtype (:issue:`56579`)
23 changes: 13 additions & 10 deletions pandas/core/arrays/arrow/array.py
@@ -2328,20 +2328,23 @@ def _str_fullmatch(
         return self._str_match(pat, case, flags, na)

     def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
-        if start != 0 and end is not None:
+        if (start == 0 or start is None) and end is None:
Member

Any chance you've already looked at using a slice object here? It feels like that could simplify a lot of the branching on None / 0.

Contributor Author

I don't think a slice object would help here. I believe this method needs a non-negative start index so that we can compute the offset of the match within the original string. This PR essentially converts a negative start into the equivalent non-negative index, which can then be added to the index returned by pc.find_substring on the sliced string. For example, if the string is "abc" and start is -1, we calculate start_offset as 2, the equivalent non-negative index. Adding this offset to the position found within the slice gives the index of the substring in the original string.
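
The conversion described above can be sketched in plain Python on a single `str`. This is only an illustration of the offset arithmetic, not the pyarrow code in this PR: `find_via_slice` is a hypothetical helper name, and the sketch only aims to agree with `str.find` for non-empty `sub`.

```python
from typing import Optional


def find_via_slice(
    s: str, sub: str, start: Optional[int] = 0, end: Optional[int] = None
) -> int:
    """Emulate s.find(sub, start, end) by searching a slice and re-adding an offset.

    Hypothetical helper for illustration; assumes a non-empty ``sub``.
    """
    if start is None:
        start_offset = 0
    elif start < 0:
        # Equivalent non-negative index, clamped at 0 when abs(start) > len(s).
        start_offset = max(0, start + len(s))
    else:
        start_offset = start
    pos = s[start:end].find(sub)
    # -1 means "not found" and must not be shifted by the offset.
    return pos if pos == -1 else pos + start_offset


print(find_via_slice("abc", "c", start=-1))              # 2
print(find_via_slice("abcdefg", "d", start=-6, end=-3))  # 3
print(find_via_slice("abc", "ab", start=1))              # -1
```

The second call mirrors the new `test_str_find_negative_start_negative_end` case: the slice is "bcd", "d" is at position 2 within it, and adding the offset 1 gives 3.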

+            result = pc.find_substring(self._pa_array, sub)
+        else:
+            if start is None:
+                start_offset = 0
+                start = 0
+            elif start < 0:
+                start_offset = pc.add(start, pc.utf8_length(self._pa_array))
+                start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
+            else:
+                start_offset = start
             slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
             result = pc.find_substring(slices, sub)
-            not_found = pc.equal(result, -1)
-            start_offset = max(0, start)
+            not_found = pc.equal(result, pa.scalar(-1, type=result.type))
+
             offset_result = pc.add(result, start_offset)
             result = pc.if_else(not_found, result, offset_result)
-        elif start == 0 and end is None:
-            slices = self._pa_array
-            result = pc.find_substring(slices, sub)
-        else:
-            raise NotImplementedError(
-                f"find not implemented with {sub=}, {start=}, {end=}"
-            )
         return type(self)(result)

     def _str_join(self, sep: str) -> Self:
24 changes: 16 additions & 8 deletions pandas/tests/extension/test_arrow.py
@@ -1919,13 +1919,19 @@ def test_str_fullmatch(pat, case, na, exp):


 @pytest.mark.parametrize(
-    "sub, start, end, exp, exp_typ",
-    [["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [1, None], pa.int64()]],
+    "sub, start, end, exp, exp_type",
+    [
+        ["ab", 0, None, [0, None], pa.int32()],
+        ["bc", 1, 3, [1, None], pa.int64()],
+        ["ab", 1, None, [-1, None], pa.int64()],
+        ["ab", -3, -3, [-1, None], pa.int64()],
+        ["ab", None, None, [0, None], pa.int32()],
+    ],
 )
-def test_str_find(sub, start, end, exp, exp_typ):
+def test_str_find(sub, start, end, exp, exp_type):
     ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
     result = ser.str.find(sub, start=start, end=end)
-    expected = pd.Series(exp, dtype=ArrowDtype(exp_typ))
+    expected = pd.Series(exp, dtype=ArrowDtype(exp_type))
     tm.assert_series_equal(result, expected)


@@ -1937,10 +1943,12 @@ def test_str_find_negative_start():
     tm.assert_series_equal(result, expected)


-def test_str_find_notimplemented():
-    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
-    with pytest.raises(NotImplementedError, match="find not implemented"):
-        ser.str.find("ab", start=1)
+def test_str_find_negative_start_negative_end():
+    # GH 56791
+    ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.find(sub="d", start=-6, end=-3)
+    expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64()))
+    tm.assert_series_equal(result, expected)


 @pytest.mark.parametrize(