Skip to content

Commit 807d8d5

Browse files
jbrockmendeljorisvandenbossche
authored andcommitted
REF (string): de-duplicate str_endswith, startswith (#59568)
1 parent d64b8d8 commit 807d8d5

File tree

3 files changed

+49
-72
lines changed

3 files changed

+49
-72
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
from __future__ import annotations
22

3-
from typing import Literal
3+
from typing import (
4+
TYPE_CHECKING,
5+
Literal,
6+
)
47

58
import numpy as np
69

710
from pandas.compat import pa_version_under10p1
811

12+
from pandas.core.dtypes.missing import isna
13+
914
if not pa_version_under10p1:
1015
import pyarrow as pa
1116
import pyarrow.compute as pc
1217

18+
if TYPE_CHECKING:
19+
from collections.abc import Sized
20+
21+
from pandas._typing import Scalar
22+
1323

1424
class ArrowStringArrayMixin:
15-
_pa_array = None
25+
_pa_array: Sized
1626

1727
def __init__(self, *args, **kwargs) -> None:
1828
raise NotImplementedError
@@ -90,3 +100,37 @@ def _str_removesuffix(self, suffix: str):
90100
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
91101
result = pc.if_else(ends_with, removed, self._pa_array)
92102
return type(self)(result)
103+
104+
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
105+
if isinstance(pat, str):
106+
result = pc.starts_with(self._pa_array, pattern=pat)
107+
else:
108+
if len(pat) == 0:
109+
# For empty tuple we return null for missing values and False
110+
# for valid values.
111+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
112+
else:
113+
result = pc.starts_with(self._pa_array, pattern=pat[0])
114+
115+
for p in pat[1:]:
116+
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
117+
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
118+
result = result.fill_null(na)
119+
return self._convert_bool_result(result)
120+
121+
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
122+
if isinstance(pat, str):
123+
result = pc.ends_with(self._pa_array, pattern=pat)
124+
else:
125+
if len(pat) == 0:
126+
# For empty tuple we return null for missing values and False
127+
# for valid values.
128+
result = pc.if_else(pc.is_null(self._pa_array), None, False)
129+
else:
130+
result = pc.ends_with(self._pa_array, pattern=pat[0])
131+
132+
for p in pat[1:]:
133+
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
134+
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
135+
result = result.fill_null(na)
136+
return self._convert_bool_result(result)

pandas/core/arrays/arrow/array.py

+1-32
Original file line numberDiff line numberDiff line change
@@ -2311,38 +2311,7 @@ def _str_contains(
23112311
result = result.fill_null(na)
23122312
return type(self)(result)
23132313

2314-
def _str_startswith(self, pat: str | tuple[str, ...], na=None):
2315-
if isinstance(pat, str):
2316-
result = pc.starts_with(self._pa_array, pattern=pat)
2317-
else:
2318-
if len(pat) == 0:
2319-
# For empty tuple, pd.StringDtype() returns null for missing values
2320-
# and false for valid values.
2321-
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2322-
else:
2323-
result = pc.starts_with(self._pa_array, pattern=pat[0])
2324-
2325-
for p in pat[1:]:
2326-
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
2327-
if not isna(na):
2328-
result = result.fill_null(na)
2329-
return type(self)(result)
2330-
2331-
def _str_endswith(self, pat: str | tuple[str, ...], na=None):
2332-
if isinstance(pat, str):
2333-
result = pc.ends_with(self._pa_array, pattern=pat)
2334-
else:
2335-
if len(pat) == 0:
2336-
# For empty tuple, pd.StringDtype() returns null for missing values
2337-
# and false for valid values.
2338-
result = pc.if_else(pc.is_null(self._pa_array), None, False)
2339-
else:
2340-
result = pc.ends_with(self._pa_array, pattern=pat[0])
2341-
2342-
for p in pat[1:]:
2343-
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
2344-
if not isna(na):
2345-
result = result.fill_null(na)
2314+
def _result_converter(self, result):
23462315
return type(self)(result)
23472316

23482317
def _str_replace(

pandas/core/arrays/string_arrow.py

+2-38
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ def _data(self):
284284
# String methods interface
285285

286286
_str_map = BaseStringArray._str_map
287+
_str_startswith = ArrowStringArrayMixin._str_startswith
288+
_str_endswith = ArrowStringArrayMixin._str_endswith
287289

288290
def _str_contains(
289291
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -301,44 +303,6 @@ def _str_contains(
301303
result[isna(result)] = bool(na)
302304
return result
303305

304-
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
305-
if isinstance(pat, str):
306-
result = pc.starts_with(self._pa_array, pattern=pat)
307-
else:
308-
if len(pat) == 0:
309-
# mimic existing behaviour of string extension array
310-
# and python string method
311-
result = pa.array(
312-
np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
313-
)
314-
else:
315-
result = pc.starts_with(self._pa_array, pattern=pat[0])
316-
317-
for p in pat[1:]:
318-
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
319-
if not isna(na):
320-
result = result.fill_null(na)
321-
return self._convert_bool_result(result)
322-
323-
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
324-
if isinstance(pat, str):
325-
result = pc.ends_with(self._pa_array, pattern=pat)
326-
else:
327-
if len(pat) == 0:
328-
# mimic existing behaviour of string extension array
329-
# and python string method
330-
result = pa.array(
331-
np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
332-
)
333-
else:
334-
result = pc.ends_with(self._pa_array, pattern=pat[0])
335-
336-
for p in pat[1:]:
337-
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
338-
if not isna(na):
339-
result = result.fill_null(na)
340-
return self._convert_bool_result(result)
341-
342306
def _str_replace(
343307
self,
344308
pat: str | re.Pattern,

0 commit comments

Comments
 (0)