Skip to content

Commit 08886d1

Browse files
committed
FIX: support tuple arguments in str.startswith and str.endswith for the "string[pyarrow]" dtype (GH#54942)
1 parent 0b5bc9c commit 08886d1

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

pandas/core/arrays/string_arrow.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,17 +330,47 @@ def _str_contains(
330330
result[isna(result)] = bool(na)
331331
return result
332332

333-
def _str_startswith(self, pat: str, na=None):
334-
result = pc.starts_with(self._pa_array, pattern=pat)
333+
def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
334+
if isinstance(pat, str):
335+
result = pc.starts_with(self._pa_array, pattern=pat)
336+
elif isinstance(pat, tuple) and all(isinstance(x, str) for x in pat):
337+
if len(pat) == 0:
338+
# mimic existing behaviour of string extension array
339+
# and python string method
340+
result = pa.array(
341+
np.full(len(self._pa_array), False), mask=isna(self._pa_array)
342+
)
343+
else:
344+
result = pc.starts_with(self._pa_array, pattern=pat[0])
345+
346+
for p in pat[1:]:
347+
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
348+
else:
349+
raise TypeError("pat must be str or tuple[str, ...]")
335350
if not isna(na):
336351
result = result.fill_null(na)
337352
result = self._result_converter(result)
338353
if not isna(na):
339354
result[isna(result)] = bool(na)
340355
return result
341356

342-
def _str_endswith(self, pat: str, na=None):
343-
result = pc.ends_with(self._pa_array, pattern=pat)
357+
def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
358+
if isinstance(pat, str):
359+
result = pc.ends_with(self._pa_array, pattern=pat)
360+
elif isinstance(pat, tuple) and all(isinstance(x, str) for x in pat):
361+
if len(pat) == 0:
362+
# mimic existing behaviour of string extension array
363+
# and python string method
364+
result = pa.array(
365+
np.full(len(self._pa_array), False), mask=isna(self._pa_array)
366+
)
367+
else:
368+
result = pc.ends_with(self._pa_array, pattern=pat[0])
369+
370+
for p in pat[1:]:
371+
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
372+
else:
373+
raise TypeError("pat must be of type str or tuple[str, ...]")
344374
if not isna(na):
345375
result = result.fill_null(na)
346376
result = self._result_converter(result)

0 commit comments

Comments
 (0)