Skip to content

Commit 7272da7

Browse files
phoflmeeseeksmachine
authored and committed
Backport PR pandas-dev#54534: REF: Move methods that can be shared with new string dtype
1 parent 08edd64 commit 7272da7

File tree

2 files changed

+89
-63
lines changed

2 files changed

+89
-63
lines changed
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
5+
import numpy as np
6+
7+
from pandas.compat import pa_version_under7p0
8+
9+
if not pa_version_under7p0:
10+
import pyarrow as pa
11+
import pyarrow.compute as pc
12+
13+
14+
class ArrowStringArrayMixin:
    """
    String-method implementations shared by pyarrow-backed arrays.

    Every method reads ``self._pa_array`` and wraps the result of a
    ``pyarrow.compute`` call with ``type(self)(...)``. Concrete subclasses
    are therefore expected to assign ``_pa_array`` and to provide a
    constructor that accepts the computed pyarrow data.
    """

    # Placeholder only; subclasses assign the backing pyarrow data.
    _pa_array = None

    def __init__(self, *args, **kwargs) -> None:
        # The mixin cannot be instantiated directly; subclasses override this.
        raise NotImplementedError

    def _str_pad(
        self,
        width: int,
        side: Literal["left", "right", "both"] = "left",
        fillchar: str = " ",
    ):
        """
        Pad each string to ``width`` with ``fillchar``.

        Parameters
        ----------
        width : int
            Passed to the pyarrow kernel as ``width``.
        side : {"left", "right", "both"}, default "left"
            Selects ``pc.utf8_lpad``, ``pc.utf8_rpad`` or ``pc.utf8_center``
            respectively.
        fillchar : str, default " "
            Passed to the pyarrow kernel as ``padding``.

        Returns
        -------
        type(self)
            New array wrapping the padded result.

        Raises
        ------
        ValueError
            If ``side`` is not one of "left", "right" or "both".
        """
        if side == "left":
            pa_pad = pc.utf8_lpad
        elif side == "right":
            pa_pad = pc.utf8_rpad
        elif side == "both":
            pa_pad = pc.utf8_center
        else:
            raise ValueError(
                f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
            )
        return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

    def _str_get(self, i: int):
        """
        Select the element at position ``i`` of each string.

        Rows where ``i`` falls outside the string, and rows whose length is
        null, yield null in the result.

        Parameters
        ----------
        i : int
            Position to extract; negative values count from the end.

        Returns
        -------
        type(self)
            New array wrapping the selected values.
        """
        lengths = pc.utf8_length(self._pa_array)
        if i >= 0:
            out_of_bounds = pc.greater_equal(i, lengths)
            start = i
            stop = i + 1
            step = 1
        else:
            out_of_bounds = pc.greater(-i, lengths)
            # A forward slice would need stop = i + 1, which is 0 when
            # i == -1; stepping backwards by one avoids that.
            start = i
            stop = i - 1
            step = -1
        # A null length (null input row) counts as out of bounds.
        not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
        selected = pc.utf8_slice_codeunits(
            self._pa_array, start=start, stop=stop, step=step
        )
        null_value = pa.scalar(
            None, type=self._pa_array.type  # type: ignore[attr-defined]
        )
        result = pc.if_else(not_out_of_bounds, selected, null_value)
        return type(self)(result)

    def _str_slice_replace(
        self, start: int | None = None, stop: int | None = None, repl: str | None = None
    ):
        """
        Replace the ``start``/``stop`` slice of each string with ``repl``.

        Parameters
        ----------
        start : int or None, default None
            ``None`` is treated as 0.
        stop : int or None, default None
            ``None`` is treated as the largest int64 value.
        repl : str or None, default None
            ``None`` is treated as the empty string.

        Returns
        -------
        type(self)
            New array wrapping the result of ``pc.utf8_replace_slice``.
        """
        if repl is None:
            repl = ""
        if start is None:
            start = 0
        if stop is None:
            stop = np.iinfo(np.int64).max
        return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

    def _str_capitalize(self):
        """Return a new array with ``pc.utf8_capitalize`` applied."""
        return type(self)(pc.utf8_capitalize(self._pa_array))

    def _str_title(self):
        """Return a new array with ``pc.utf8_title`` applied."""
        return type(self)(pc.utf8_title(self._pa_array))

    def _str_swapcase(self):
        """Return a new array with ``pc.utf8_swapcase`` applied."""
        return type(self)(pc.utf8_swapcase(self._pa_array))

    def _str_removesuffix(self, suffix: str):
        """
        Remove ``suffix`` from the end of each string that ends with it.

        Strings that do not end with ``suffix`` are left unchanged. An empty
        ``suffix`` removes nothing, matching ``str.removesuffix("")``.

        Parameters
        ----------
        suffix : str
            Suffix to strip.

        Returns
        -------
        type(self)
            New array wrapping the result.
        """
        if not suffix:
            # -len("") is 0, so the slice below would stop at index 0 and
            # produce an empty string for every row reported as a match.
            return type(self)(self._pa_array)
        ends_with = pc.ends_with(self._pa_array, pattern=suffix)
        removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
        result = pc.if_else(ends_with, removed, self._pa_array)
        return type(self)(result)

pandas/core/arrays/arrow/array.py

+5-63
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
from pandas.core import roperator
4444
from pandas.core.arraylike import OpsMixin
45+
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
4546
from pandas.core.arrays.base import (
4647
ExtensionArray,
4748
ExtensionArraySupportsAnyAll,
@@ -184,7 +185,10 @@ def to_pyarrow_type(
184185

185186

186187
class ArrowExtensionArray(
187-
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
188+
OpsMixin,
189+
ExtensionArraySupportsAnyAll,
190+
ArrowStringArrayMixin,
191+
BaseStringArrayMethods,
188192
):
189193
"""
190194
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
@@ -1986,24 +1990,6 @@ def _str_count(self, pat: str, flags: int = 0):
19861990
raise NotImplementedError(f"count not implemented with {flags=}")
19871991
return type(self)(pc.count_substring_regex(self._pa_array, pat))
19881992

1989-
def _str_pad(
1990-
self,
1991-
width: int,
1992-
side: Literal["left", "right", "both"] = "left",
1993-
fillchar: str = " ",
1994-
):
1995-
if side == "left":
1996-
pa_pad = pc.utf8_lpad
1997-
elif side == "right":
1998-
pa_pad = pc.utf8_rpad
1999-
elif side == "both":
2000-
pa_pad = pc.utf8_center
2001-
else:
2002-
raise ValueError(
2003-
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
2004-
)
2005-
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))
2006-
20071993
def _str_contains(
20081994
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
20091995
):
@@ -2088,26 +2074,6 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
20882074
)
20892075
return type(self)(result)
20902076

2091-
def _str_get(self, i: int):
2092-
lengths = pc.utf8_length(self._pa_array)
2093-
if i >= 0:
2094-
out_of_bounds = pc.greater_equal(i, lengths)
2095-
start = i
2096-
stop = i + 1
2097-
step = 1
2098-
else:
2099-
out_of_bounds = pc.greater(-i, lengths)
2100-
start = i
2101-
stop = i - 1
2102-
step = -1
2103-
not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
2104-
selected = pc.utf8_slice_codeunits(
2105-
self._pa_array, start=start, stop=stop, step=step
2106-
)
2107-
null_value = pa.scalar(None, type=self._pa_array.type)
2108-
result = pc.if_else(not_out_of_bounds, selected, null_value)
2109-
return type(self)(result)
2110-
21112077
def _str_join(self, sep: str):
21122078
if pa.types.is_string(self._pa_array.type):
21132079
result = self._apply_elementwise(list)
@@ -2137,15 +2103,6 @@ def _str_slice(
21372103
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
21382104
)
21392105

2140-
def _str_slice_replace(
2141-
self, start: int | None = None, stop: int | None = None, repl: str | None = None
2142-
):
2143-
if repl is None:
2144-
repl = ""
2145-
if start is None:
2146-
start = 0
2147-
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))
2148-
21492106
def _str_isalnum(self):
21502107
return type(self)(pc.utf8_is_alnum(self._pa_array))
21512108

@@ -2170,18 +2127,9 @@ def _str_isspace(self):
21702127
def _str_istitle(self):
21712128
return type(self)(pc.utf8_is_title(self._pa_array))
21722129

2173-
def _str_capitalize(self):
2174-
return type(self)(pc.utf8_capitalize(self._pa_array))
2175-
2176-
def _str_title(self):
2177-
return type(self)(pc.utf8_title(self._pa_array))
2178-
21792130
def _str_isupper(self):
21802131
return type(self)(pc.utf8_is_upper(self._pa_array))
21812132

2182-
def _str_swapcase(self):
2183-
return type(self)(pc.utf8_swapcase(self._pa_array))
2184-
21852133
def _str_len(self):
21862134
return type(self)(pc.utf8_length(self._pa_array))
21872135

@@ -2222,12 +2170,6 @@ def _str_removeprefix(self, prefix: str):
22222170
result = self._apply_elementwise(predicate)
22232171
return type(self)(pa.chunked_array(result))
22242172

2225-
def _str_removesuffix(self, suffix: str):
2226-
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
2227-
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
2228-
result = pc.if_else(ends_with, removed, self._pa_array)
2229-
return type(self)(result)
2230-
22312173
def _str_casefold(self):
22322174
predicate = lambda val: val.casefold()
22332175
result = self._apply_elementwise(predicate)

0 commit comments

Comments
 (0)