Skip to content

Commit a318356

Browse files
rohanjain101 authored and cbpygit committed
Support tuple in startswith/endswith for arrow strings (pandas-dev#56580)
1 parent 7db5c70 commit a318356

File tree

3 files changed, with 46 additions and 9 deletions.

doc/source/whatsnew/v2.2.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -603,6 +603,7 @@ Strings
603603
- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`)
604604
- Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`)
605605
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
606+
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for :class:`ArrowDtype` with ``pyarrow.string`` dtype (:issue:`56579`)
606607
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
607608

608609
Interval

pandas/core/arrays/arrow/array.py

+26 -4
Original file line number | Diff line number | Diff line change
@@ -2174,14 +2174,36 @@ def _str_contains(
21742174
result = result.fill_null(na)
21752175
return type(self)(result)
21762176

2177-
def _str_startswith(self, pat: str, na=None):
2178-
result = pc.starts_with(self._pa_array, pattern=pat)
2177+
def _str_startswith(self, pat: str | tuple[str, ...], na=None):
2178+
if isinstance(pat, str):
2179+
result = pc.starts_with(self._pa_array, pattern=pat)
2180+
else:
2181+
if len(pat) == 0:
2182+
# For empty tuple, pd.StringDtype() returns null for missing values
2183+
# and false for valid values.
2184+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2185+
else:
2186+
result = pc.starts_with(self._pa_array, pattern=pat[0])
2187+
2188+
for p in pat[1:]:
2189+
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
21792190
if not isna(na):
21802191
result = result.fill_null(na)
21812192
return type(self)(result)
21822193

2183-
def _str_endswith(self, pat: str, na=None):
2184-
result = pc.ends_with(self._pa_array, pattern=pat)
2194+
def _str_endswith(self, pat: str | tuple[str, ...], na=None):
2195+
if isinstance(pat, str):
2196+
result = pc.ends_with(self._pa_array, pattern=pat)
2197+
else:
2198+
if len(pat) == 0:
2199+
# For empty tuple, pd.StringDtype() returns null for missing values
2200+
# and false for valid values.
2201+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2202+
else:
2203+
result = pc.ends_with(self._pa_array, pattern=pat[0])
2204+
2205+
for p in pat[1:]:
2206+
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
21852207
if not isna(na):
21862208
result = result.fill_null(na)
21872209
return type(self)(result)

pandas/tests/extension/test_arrow.py

+19 -5
Original file line number | Diff line number | Diff line change
@@ -1801,19 +1801,33 @@ def test_str_contains_flags_unsupported():
18011801
@pytest.mark.parametrize(
18021802
"side, pat, na, exp",
18031803
[
1804-
["startswith", "ab", None, [True, None]],
1805-
["startswith", "b", False, [False, False]],
1806-
["endswith", "b", True, [False, True]],
1807-
["endswith", "bc", None, [True, None]],
1804+
["startswith", "ab", None, [True, None, False]],
1805+
["startswith", "b", False, [False, False, False]],
1806+
["endswith", "b", True, [False, True, False]],
1807+
["endswith", "bc", None, [True, None, False]],
1808+
["startswith", ("a", "e", "g"), None, [True, None, True]],
1809+
["endswith", ("a", "c", "g"), None, [True, None, True]],
1810+
["startswith", (), None, [False, None, False]],
1811+
["endswith", (), None, [False, None, False]],
18081812
],
18091813
)
18101814
def test_str_start_ends_with(side, pat, na, exp):
1811-
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
1815+
ser = pd.Series(["abc", None, "efg"], dtype=ArrowDtype(pa.string()))
18121816
result = getattr(ser.str, side)(pat, na=na)
18131817
expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
18141818
tm.assert_series_equal(result, expected)
18151819

18161820

1821+
@pytest.mark.parametrize("side", ("startswith", "endswith"))
1822+
def test_str_starts_ends_with_all_nulls_empty_tuple(side):
1823+
ser = pd.Series([None, None], dtype=ArrowDtype(pa.string()))
1824+
result = getattr(ser.str, side)(())
1825+
1826+
# bool datatype preserved for all nulls.
1827+
expected = pd.Series([None, None], dtype=ArrowDtype(pa.bool_()))
1828+
tm.assert_series_equal(result, expected)
1829+
1830+
18171831
@pytest.mark.parametrize(
18181832
"arg_name, arg",
18191833
[["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]],

0 commit comments

Comments
 (0)