Skip to content

Commit 4444e52

Browse files
authored
REF (string): de-duplicate ArrowStringArray methods (#59555)
1 parent 16b7288 commit 4444e52

File tree

3 files changed

+103
-172
lines changed

3 files changed

+103
-172
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+83
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial
4+
import re
45
from typing import (
56
TYPE_CHECKING,
67
Any,
@@ -48,6 +49,37 @@ def _convert_int_result(self, result):
4849
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
4950
raise NotImplementedError
5051

52+
def _str_len(self):
53+
result = pc.utf8_length(self._pa_array)
54+
return self._convert_int_result(result)
55+
56+
def _str_lower(self) -> Self:
57+
return type(self)(pc.utf8_lower(self._pa_array))
58+
59+
def _str_upper(self) -> Self:
60+
return type(self)(pc.utf8_upper(self._pa_array))
61+
62+
def _str_strip(self, to_strip=None) -> Self:
63+
if to_strip is None:
64+
result = pc.utf8_trim_whitespace(self._pa_array)
65+
else:
66+
result = pc.utf8_trim(self._pa_array, characters=to_strip)
67+
return type(self)(result)
68+
69+
def _str_lstrip(self, to_strip=None) -> Self:
70+
if to_strip is None:
71+
result = pc.utf8_ltrim_whitespace(self._pa_array)
72+
else:
73+
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
74+
return type(self)(result)
75+
76+
def _str_rstrip(self, to_strip=None) -> Self:
77+
if to_strip is None:
78+
result = pc.utf8_rtrim_whitespace(self._pa_array)
79+
else:
80+
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
81+
return type(self)(result)
82+
5183
def _str_pad(
5284
self,
5385
width: int,
@@ -128,6 +160,33 @@ def _str_slice_replace(
128160
stop = np.iinfo(np.int64).max
129161
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))
130162

163+
def _str_replace(
164+
self,
165+
pat: str | re.Pattern,
166+
repl: str | Callable,
167+
n: int = -1,
168+
case: bool = True,
169+
flags: int = 0,
170+
regex: bool = True,
171+
) -> Self:
172+
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
173+
raise NotImplementedError(
174+
"replace is not supported with a re.Pattern, callable repl, "
175+
"case=False, or flags!=0"
176+
)
177+
178+
func = pc.replace_substring_regex if regex else pc.replace_substring
179+
# https://github.com/apache/arrow/issues/39149
180+
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
181+
pa_max_replacements = None if n < 0 else n
182+
result = func(
183+
self._pa_array,
184+
pattern=pat,
185+
replacement=repl,
186+
max_replacements=pa_max_replacements,
187+
)
188+
return type(self)(result)
189+
131190
def _str_capitalize(self) -> Self:
132191
return type(self)(pc.utf8_capitalize(self._pa_array))
133192

@@ -137,6 +196,16 @@ def _str_title(self) -> Self:
137196
def _str_swapcase(self) -> Self:
138197
return type(self)(pc.utf8_swapcase(self._pa_array))
139198

199+
def _str_removeprefix(self, prefix: str):
200+
if not pa_version_under13p0:
201+
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
202+
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
203+
result = pc.if_else(starts_with, removed, self._pa_array)
204+
return type(self)(result)
205+
predicate = lambda val: val.removeprefix(prefix)
206+
result = self._apply_elementwise(predicate)
207+
return type(self)(pa.chunked_array(result))
208+
140209
def _str_removesuffix(self, suffix: str):
141210
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
142211
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
@@ -228,6 +297,20 @@ def _str_contains(
228297
result = result.fill_null(na)
229298
return self._convert_bool_result(result)
230299

300+
def _str_match(
301+
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
302+
):
303+
if not pat.startswith("^"):
304+
pat = f"^{pat}"
305+
return self._str_contains(pat, case, flags, na, regex=True)
306+
307+
def _str_fullmatch(
308+
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
309+
):
310+
if not pat.endswith("$") or pat.endswith("\\$"):
311+
pat = f"{pat}$"
312+
return self._str_match(pat, case, flags, na)
313+
231314
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
232315
if (
233316
pa_version_under13p0

pandas/core/arrays/arrow/array.py

+1-85
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,7 @@ def _rank(
19991999
"""
20002000
See Series.rank.__doc__.
20012001
"""
2002-
return type(self)(
2002+
return self._convert_int_result(
20032003
self._rank_calc(
20042004
axis=axis,
20052005
method=method,
@@ -2323,57 +2323,13 @@ def _str_count(self, pat: str, flags: int = 0) -> Self:
23232323
raise NotImplementedError(f"count not implemented with {flags=}")
23242324
return type(self)(pc.count_substring_regex(self._pa_array, pat))
23252325

2326-
def _result_converter(self, result):
2327-
return type(self)(result)
2328-
2329-
def _str_replace(
2330-
self,
2331-
pat: str | re.Pattern,
2332-
repl: str | Callable,
2333-
n: int = -1,
2334-
case: bool = True,
2335-
flags: int = 0,
2336-
regex: bool = True,
2337-
) -> Self:
2338-
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
2339-
raise NotImplementedError(
2340-
"replace is not supported with a re.Pattern, callable repl, "
2341-
"case=False, or flags!=0"
2342-
)
2343-
2344-
func = pc.replace_substring_regex if regex else pc.replace_substring
2345-
# https://github.com/apache/arrow/issues/39149
2346-
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
2347-
pa_max_replacements = None if n < 0 else n
2348-
result = func(
2349-
self._pa_array,
2350-
pattern=pat,
2351-
replacement=repl,
2352-
max_replacements=pa_max_replacements,
2353-
)
2354-
return type(self)(result)
2355-
23562326
def _str_repeat(self, repeats: int | Sequence[int]) -> Self:
23572327
if not isinstance(repeats, int):
23582328
raise NotImplementedError(
23592329
f"repeat is not implemented when repeats is {type(repeats).__name__}"
23602330
)
23612331
return type(self)(pc.binary_repeat(self._pa_array, repeats))
23622332

2363-
def _str_match(
2364-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
2365-
) -> Self:
2366-
if not pat.startswith("^"):
2367-
pat = f"^{pat}"
2368-
return self._str_contains(pat, case, flags, na, regex=True)
2369-
2370-
def _str_fullmatch(
2371-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
2372-
) -> Self:
2373-
if not pat.endswith("$") or pat.endswith("\\$"):
2374-
pat = f"{pat}$"
2375-
return self._str_match(pat, case, flags, na)
2376-
23772333
def _str_join(self, sep: str) -> Self:
23782334
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
23792335
self._pa_array.type
@@ -2394,46 +2350,6 @@ def _str_rpartition(self, sep: str, expand: bool) -> Self:
23942350
result = self._apply_elementwise(predicate)
23952351
return type(self)(pa.chunked_array(result))
23962352

2397-
def _str_len(self) -> Self:
2398-
return type(self)(pc.utf8_length(self._pa_array))
2399-
2400-
def _str_lower(self) -> Self:
2401-
return type(self)(pc.utf8_lower(self._pa_array))
2402-
2403-
def _str_upper(self) -> Self:
2404-
return type(self)(pc.utf8_upper(self._pa_array))
2405-
2406-
def _str_strip(self, to_strip=None) -> Self:
2407-
if to_strip is None:
2408-
result = pc.utf8_trim_whitespace(self._pa_array)
2409-
else:
2410-
result = pc.utf8_trim(self._pa_array, characters=to_strip)
2411-
return type(self)(result)
2412-
2413-
def _str_lstrip(self, to_strip=None) -> Self:
2414-
if to_strip is None:
2415-
result = pc.utf8_ltrim_whitespace(self._pa_array)
2416-
else:
2417-
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
2418-
return type(self)(result)
2419-
2420-
def _str_rstrip(self, to_strip=None) -> Self:
2421-
if to_strip is None:
2422-
result = pc.utf8_rtrim_whitespace(self._pa_array)
2423-
else:
2424-
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
2425-
return type(self)(result)
2426-
2427-
def _str_removeprefix(self, prefix: str):
2428-
if not pa_version_under13p0:
2429-
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
2430-
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
2431-
result = pc.if_else(starts_with, removed, self._pa_array)
2432-
return type(self)(result)
2433-
predicate = lambda val: val.removeprefix(prefix)
2434-
result = self._apply_elementwise(predicate)
2435-
return type(self)(pa.chunked_array(result))
2436-
24372353
def _str_casefold(self) -> Self:
24382354
predicate = lambda val: val.casefold()
24392355
result = self._apply_elementwise(predicate)

pandas/core/arrays/string_arrow.py

+19-87
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@
5050

5151
from pandas._typing import (
5252
ArrayLike,
53-
AxisInt,
5453
Dtype,
5554
NpDtype,
56-
Scalar,
5755
Self,
5856
npt,
5957
)
@@ -290,6 +288,20 @@ def astype(self, dtype, copy: bool = True):
290288
_str_startswith = ArrowStringArrayMixin._str_startswith
291289
_str_endswith = ArrowStringArrayMixin._str_endswith
292290
_str_pad = ArrowStringArrayMixin._str_pad
291+
_str_match = ArrowStringArrayMixin._str_match
292+
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
293+
_str_lower = ArrowStringArrayMixin._str_lower
294+
_str_upper = ArrowStringArrayMixin._str_upper
295+
_str_strip = ArrowStringArrayMixin._str_strip
296+
_str_lstrip = ArrowStringArrayMixin._str_lstrip
297+
_str_rstrip = ArrowStringArrayMixin._str_rstrip
298+
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
299+
_str_get = ArrowStringArrayMixin._str_get
300+
_str_capitalize = ArrowStringArrayMixin._str_capitalize
301+
_str_title = ArrowStringArrayMixin._str_title
302+
_str_swapcase = ArrowStringArrayMixin._str_swapcase
303+
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
304+
_str_len = ArrowStringArrayMixin._str_len
293305
_str_slice = ArrowStringArrayMixin._str_slice
294306

295307
def _str_contains(
@@ -323,73 +335,21 @@ def _str_replace(
323335
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
324336
return super()._str_replace(pat, repl, n, case, flags, regex)
325337

326-
return ArrowExtensionArray._str_replace(self, pat, repl, n, case, flags, regex)
338+
return ArrowStringArrayMixin._str_replace(
339+
self, pat, repl, n, case, flags, regex
340+
)
327341

328342
def _str_repeat(self, repeats: int | Sequence[int]):
329343
if not isinstance(repeats, int):
330344
return super()._str_repeat(repeats)
331345
else:
332-
return type(self)(pc.binary_repeat(self._pa_array, repeats))
333-
334-
def _str_match(
335-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
336-
):
337-
if not pat.startswith("^"):
338-
pat = f"^{pat}"
339-
return self._str_contains(pat, case, flags, na, regex=True)
340-
341-
def _str_fullmatch(
342-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
343-
):
344-
if not pat.endswith("$") or pat.endswith("\\$"):
345-
pat = f"{pat}$"
346-
return self._str_match(pat, case, flags, na)
347-
348-
def _str_len(self):
349-
result = pc.utf8_length(self._pa_array)
350-
return self._convert_int_result(result)
351-
352-
def _str_lower(self) -> Self:
353-
return type(self)(pc.utf8_lower(self._pa_array))
354-
355-
def _str_upper(self) -> Self:
356-
return type(self)(pc.utf8_upper(self._pa_array))
357-
358-
def _str_strip(self, to_strip=None) -> Self:
359-
if to_strip is None:
360-
result = pc.utf8_trim_whitespace(self._pa_array)
361-
else:
362-
result = pc.utf8_trim(self._pa_array, characters=to_strip)
363-
return type(self)(result)
364-
365-
def _str_lstrip(self, to_strip=None) -> Self:
366-
if to_strip is None:
367-
result = pc.utf8_ltrim_whitespace(self._pa_array)
368-
else:
369-
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
370-
return type(self)(result)
371-
372-
def _str_rstrip(self, to_strip=None) -> Self:
373-
if to_strip is None:
374-
result = pc.utf8_rtrim_whitespace(self._pa_array)
375-
else:
376-
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
377-
return type(self)(result)
346+
return ArrowExtensionArray._str_repeat(self, repeats=repeats)
378347

379348
def _str_removeprefix(self, prefix: str):
380349
if not pa_version_under13p0:
381-
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
382-
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
383-
result = pc.if_else(starts_with, removed, self._pa_array)
384-
return type(self)(result)
350+
return ArrowStringArrayMixin._str_removeprefix(self, prefix)
385351
return super()._str_removeprefix(prefix)
386352

387-
def _str_removesuffix(self, suffix: str):
388-
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
389-
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
390-
result = pc.if_else(ends_with, removed, self._pa_array)
391-
return type(self)(result)
392-
393353
def _str_count(self, pat: str, flags: int = 0):
394354
if flags:
395355
return super()._str_count(pat, flags)
@@ -456,28 +416,6 @@ def _reduce(
456416
else:
457417
return result
458418

459-
def _rank(
460-
self,
461-
*,
462-
axis: AxisInt = 0,
463-
method: str = "average",
464-
na_option: str = "keep",
465-
ascending: bool = True,
466-
pct: bool = False,
467-
):
468-
"""
469-
See Series.rank.__doc__.
470-
"""
471-
return self._convert_int_result(
472-
self._rank_calc(
473-
axis=axis,
474-
method=method,
475-
na_option=na_option,
476-
ascending=ascending,
477-
pct=pct,
478-
)
479-
)
480-
481419
def value_counts(self, dropna: bool = True) -> Series:
482420
result = super().value_counts(dropna=dropna)
483421
if self.dtype.na_value is np.nan:
@@ -499,9 +437,3 @@ def _cmp_method(self, other, op):
499437

500438
class ArrowStringArrayNumpySemantics(ArrowStringArray):
501439
_na_value = np.nan
502-
_str_get = ArrowStringArrayMixin._str_get
503-
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
504-
_str_capitalize = ArrowStringArrayMixin._str_capitalize
505-
_str_title = ArrowStringArrayMixin._str_title
506-
_str_swapcase = ArrowStringArrayMixin._str_swapcase
507-
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 commit comments

Comments
 (0)