Skip to content

Commit c8eadfd

Browse files
jbrockmendeljorisvandenbossche
authored andcommitted
REF (string): de-duplicate ArrowStringArray methods (pandas-dev#59555)
1 parent 26a0d56 commit c8eadfd

File tree

3 files changed

+108
-174
lines changed

3 files changed

+108
-174
lines changed

pandas/core/arrays/_arrow_string_mixins.py

+88-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial
4+
import re
45
from typing import (
56
TYPE_CHECKING,
67
Any,
@@ -25,7 +26,10 @@
2526
if TYPE_CHECKING:
2627
from collections.abc import Callable
2728

28-
from pandas._typing import Scalar
29+
from pandas._typing import (
30+
Scalar,
31+
Self,
32+
)
2933

3034

3135
class ArrowStringArrayMixin:
@@ -45,6 +49,37 @@ def _convert_int_result(self, result):
4549
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
4650
raise NotImplementedError
4751

52+
def _str_len(self):
53+
result = pc.utf8_length(self._pa_array)
54+
return self._convert_int_result(result)
55+
56+
def _str_lower(self) -> Self:
57+
return type(self)(pc.utf8_lower(self._pa_array))
58+
59+
def _str_upper(self) -> Self:
60+
return type(self)(pc.utf8_upper(self._pa_array))
61+
62+
def _str_strip(self, to_strip=None) -> Self:
63+
if to_strip is None:
64+
result = pc.utf8_trim_whitespace(self._pa_array)
65+
else:
66+
result = pc.utf8_trim(self._pa_array, characters=to_strip)
67+
return type(self)(result)
68+
69+
def _str_lstrip(self, to_strip=None) -> Self:
70+
if to_strip is None:
71+
result = pc.utf8_ltrim_whitespace(self._pa_array)
72+
else:
73+
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
74+
return type(self)(result)
75+
76+
def _str_rstrip(self, to_strip=None) -> Self:
77+
if to_strip is None:
78+
result = pc.utf8_rtrim_whitespace(self._pa_array)
79+
else:
80+
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
81+
return type(self)(result)
82+
4883
def _str_pad(
4984
self,
5085
width: int,
@@ -125,7 +160,34 @@ def _str_slice_replace(
125160
stop = np.iinfo(np.int64).max
126161
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))
127162

128-
def _str_capitalize(self):
163+
def _str_replace(
164+
self,
165+
pat: str | re.Pattern,
166+
repl: str | Callable,
167+
n: int = -1,
168+
case: bool = True,
169+
flags: int = 0,
170+
regex: bool = True,
171+
) -> Self:
172+
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
173+
raise NotImplementedError(
174+
"replace is not supported with a re.Pattern, callable repl, "
175+
"case=False, or flags!=0"
176+
)
177+
178+
func = pc.replace_substring_regex if regex else pc.replace_substring
179+
# https://github.com/apache/arrow/issues/39149
180+
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
181+
pa_max_replacements = None if n < 0 else n
182+
result = func(
183+
self._pa_array,
184+
pattern=pat,
185+
replacement=repl,
186+
max_replacements=pa_max_replacements,
187+
)
188+
return type(self)(result)
189+
190+
def _str_capitalize(self) -> Self:
129191
return type(self)(pc.utf8_capitalize(self._pa_array))
130192

131193
def _str_title(self):
@@ -134,6 +196,16 @@ def _str_title(self):
134196
def _str_swapcase(self):
135197
return type(self)(pc.utf8_swapcase(self._pa_array))
136198

199+
def _str_removeprefix(self, prefix: str):
200+
if not pa_version_under13p0:
201+
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
202+
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
203+
result = pc.if_else(starts_with, removed, self._pa_array)
204+
return type(self)(result)
205+
predicate = lambda val: val.removeprefix(prefix)
206+
result = self._apply_elementwise(predicate)
207+
return type(self)(pa.chunked_array(result))
208+
137209
def _str_removesuffix(self, suffix: str):
138210
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
139211
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
@@ -225,6 +297,20 @@ def _str_contains(
225297
result = result.fill_null(na)
226298
return self._convert_bool_result(result)
227299

300+
def _str_match(
301+
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
302+
):
303+
if not pat.startswith("^"):
304+
pat = f"^{pat}"
305+
return self._str_contains(pat, case, flags, na, regex=True)
306+
307+
def _str_fullmatch(
308+
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
309+
):
310+
if not pat.endswith("$") or pat.endswith("\\$"):
311+
pat = f"{pat}$"
312+
return self._str_match(pat, case, flags, na)
313+
228314
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
229315
if (
230316
pa_version_under13p0

pandas/core/arrays/arrow/array.py

+1-85
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,7 @@ def _rank(
19891989
"""
19901990
See Series.rank.__doc__.
19911991
"""
1992-
return type(self)(
1992+
return self._convert_int_result(
19931993
self._rank_calc(
19941994
axis=axis,
19951995
method=method,
@@ -2296,36 +2296,6 @@ def _str_count(self, pat: str, flags: int = 0):
22962296
raise NotImplementedError(f"count not implemented with {flags=}")
22972297
return type(self)(pc.count_substring_regex(self._pa_array, pat))
22982298

2299-
def _result_converter(self, result):
2300-
return type(self)(result)
2301-
2302-
def _str_replace(
2303-
self,
2304-
pat: str | re.Pattern,
2305-
repl: str | Callable,
2306-
n: int = -1,
2307-
case: bool = True,
2308-
flags: int = 0,
2309-
regex: bool = True,
2310-
):
2311-
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
2312-
raise NotImplementedError(
2313-
"replace is not supported with a re.Pattern, callable repl, "
2314-
"case=False, or flags!=0"
2315-
)
2316-
2317-
func = pc.replace_substring_regex if regex else pc.replace_substring
2318-
# https://github.com/apache/arrow/issues/39149
2319-
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
2320-
pa_max_replacements = None if n < 0 else n
2321-
result = func(
2322-
self._pa_array,
2323-
pattern=pat,
2324-
replacement=repl,
2325-
max_replacements=pa_max_replacements,
2326-
)
2327-
return type(self)(result)
2328-
23292299
def _str_repeat(self, repeats: int | Sequence[int]):
23302300
if not isinstance(repeats, int):
23312301
raise NotImplementedError(
@@ -2334,20 +2304,6 @@ def _str_repeat(self, repeats: int | Sequence[int]):
23342304
else:
23352305
return type(self)(pc.binary_repeat(self._pa_array, repeats))
23362306

2337-
def _str_match(
2338-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
2339-
):
2340-
if not pat.startswith("^"):
2341-
pat = f"^{pat}"
2342-
return self._str_contains(pat, case, flags, na, regex=True)
2343-
2344-
def _str_fullmatch(
2345-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
2346-
):
2347-
if not pat.endswith("$") or pat.endswith("\\$"):
2348-
pat = f"{pat}$"
2349-
return self._str_match(pat, case, flags, na)
2350-
23512307
def _str_join(self, sep: str):
23522308
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
23532309
self._pa_array.type
@@ -2368,46 +2324,6 @@ def _str_rpartition(self, sep: str, expand: bool):
23682324
result = self._apply_elementwise(predicate)
23692325
return type(self)(pa.chunked_array(result))
23702326

2371-
def _str_len(self):
2372-
return type(self)(pc.utf8_length(self._pa_array))
2373-
2374-
def _str_lower(self):
2375-
return type(self)(pc.utf8_lower(self._pa_array))
2376-
2377-
def _str_upper(self):
2378-
return type(self)(pc.utf8_upper(self._pa_array))
2379-
2380-
def _str_strip(self, to_strip=None):
2381-
if to_strip is None:
2382-
result = pc.utf8_trim_whitespace(self._pa_array)
2383-
else:
2384-
result = pc.utf8_trim(self._pa_array, characters=to_strip)
2385-
return type(self)(result)
2386-
2387-
def _str_lstrip(self, to_strip=None):
2388-
if to_strip is None:
2389-
result = pc.utf8_ltrim_whitespace(self._pa_array)
2390-
else:
2391-
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
2392-
return type(self)(result)
2393-
2394-
def _str_rstrip(self, to_strip=None):
2395-
if to_strip is None:
2396-
result = pc.utf8_rtrim_whitespace(self._pa_array)
2397-
else:
2398-
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
2399-
return type(self)(result)
2400-
2401-
def _str_removeprefix(self, prefix: str):
2402-
if not pa_version_under13p0:
2403-
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
2404-
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
2405-
result = pc.if_else(starts_with, removed, self._pa_array)
2406-
return type(self)(result)
2407-
predicate = lambda val: val.removeprefix(prefix)
2408-
result = self._apply_elementwise(predicate)
2409-
return type(self)(pa.chunked_array(result))
2410-
24112327
def _str_casefold(self):
24122328
predicate = lambda val: val.casefold()
24132329
result = self._apply_elementwise(predicate)

pandas/core/arrays/string_arrow.py

+19-87
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@
4848

4949
from pandas._typing import (
5050
ArrayLike,
51-
AxisInt,
5251
Dtype,
53-
Scalar,
5452
npt,
5553
)
5654

@@ -293,6 +291,20 @@ def _data(self):
293291
_str_startswith = ArrowStringArrayMixin._str_startswith
294292
_str_endswith = ArrowStringArrayMixin._str_endswith
295293
_str_pad = ArrowStringArrayMixin._str_pad
294+
_str_match = ArrowStringArrayMixin._str_match
295+
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
296+
_str_lower = ArrowStringArrayMixin._str_lower
297+
_str_upper = ArrowStringArrayMixin._str_upper
298+
_str_strip = ArrowStringArrayMixin._str_strip
299+
_str_lstrip = ArrowStringArrayMixin._str_lstrip
300+
_str_rstrip = ArrowStringArrayMixin._str_rstrip
301+
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
302+
_str_get = ArrowStringArrayMixin._str_get
303+
_str_capitalize = ArrowStringArrayMixin._str_capitalize
304+
_str_title = ArrowStringArrayMixin._str_title
305+
_str_swapcase = ArrowStringArrayMixin._str_swapcase
306+
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
307+
_str_len = ArrowStringArrayMixin._str_len
296308
_str_slice = ArrowStringArrayMixin._str_slice
297309

298310
def _str_contains(
@@ -326,73 +338,21 @@ def _str_replace(
326338
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
327339
return super()._str_replace(pat, repl, n, case, flags, regex)
328340

329-
return ArrowExtensionArray._str_replace(self, pat, repl, n, case, flags, regex)
341+
return ArrowStringArrayMixin._str_replace(
342+
self, pat, repl, n, case, flags, regex
343+
)
330344

331345
def _str_repeat(self, repeats: int | Sequence[int]):
332346
if not isinstance(repeats, int):
333347
return super()._str_repeat(repeats)
334348
else:
335-
return type(self)(pc.binary_repeat(self._pa_array, repeats))
336-
337-
def _str_match(
338-
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
339-
):
340-
if not pat.startswith("^"):
341-
pat = f"^{pat}"
342-
return self._str_contains(pat, case, flags, na, regex=True)
343-
344-
def _str_fullmatch(
345-
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
346-
):
347-
if not pat.endswith("$") or pat.endswith("\\$"):
348-
pat = f"{pat}$"
349-
return self._str_match(pat, case, flags, na)
350-
351-
def _str_len(self):
352-
result = pc.utf8_length(self._pa_array)
353-
return self._convert_int_result(result)
354-
355-
def _str_lower(self):
356-
return type(self)(pc.utf8_lower(self._pa_array))
357-
358-
def _str_upper(self):
359-
return type(self)(pc.utf8_upper(self._pa_array))
360-
361-
def _str_strip(self, to_strip=None):
362-
if to_strip is None:
363-
result = pc.utf8_trim_whitespace(self._pa_array)
364-
else:
365-
result = pc.utf8_trim(self._pa_array, characters=to_strip)
366-
return type(self)(result)
367-
368-
def _str_lstrip(self, to_strip=None):
369-
if to_strip is None:
370-
result = pc.utf8_ltrim_whitespace(self._pa_array)
371-
else:
372-
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
373-
return type(self)(result)
374-
375-
def _str_rstrip(self, to_strip=None):
376-
if to_strip is None:
377-
result = pc.utf8_rtrim_whitespace(self._pa_array)
378-
else:
379-
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
380-
return type(self)(result)
349+
return ArrowExtensionArray._str_repeat(self, repeats=repeats)
381350

382351
def _str_removeprefix(self, prefix: str):
383352
if not pa_version_under13p0:
384-
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
385-
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
386-
result = pc.if_else(starts_with, removed, self._pa_array)
387-
return type(self)(result)
353+
return ArrowStringArrayMixin._str_removeprefix(self, prefix)
388354
return super()._str_removeprefix(prefix)
389355

390-
def _str_removesuffix(self, suffix: str):
391-
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
392-
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
393-
result = pc.if_else(ends_with, removed, self._pa_array)
394-
return type(self)(result)
395-
396356
def _str_count(self, pat: str, flags: int = 0):
397357
if flags:
398358
return super()._str_count(pat, flags)
@@ -449,28 +409,6 @@ def _reduce(
449409
else:
450410
return result
451411

452-
def _rank(
453-
self,
454-
*,
455-
axis: AxisInt = 0,
456-
method: str = "average",
457-
na_option: str = "keep",
458-
ascending: bool = True,
459-
pct: bool = False,
460-
):
461-
"""
462-
See Series.rank.__doc__.
463-
"""
464-
return self._convert_int_result(
465-
self._rank_calc(
466-
axis=axis,
467-
method=method,
468-
na_option=na_option,
469-
ascending=ascending,
470-
pct=pct,
471-
)
472-
)
473-
474412
def value_counts(self, dropna: bool = True) -> Series:
475413
result = super().value_counts(dropna=dropna)
476414
if self.dtype.na_value is np.nan:
@@ -492,9 +430,3 @@ def _cmp_method(self, other, op):
492430

493431
class ArrowStringArrayNumpySemantics(ArrowStringArray):
494432
_na_value = np.nan
495-
_str_get = ArrowStringArrayMixin._str_get
496-
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
497-
_str_capitalize = ArrowStringArrayMixin._str_capitalize
498-
_str_title = ArrowStringArrayMixin._str_title
499-
_str_swapcase = ArrowStringArrayMixin._str_swapcase
500-
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 commit comments

Comments
 (0)