Skip to content

Commit fa94d3b

Browse files
authored
Backport PR pandas-dev#52614: ENH: Implement more str accessor methods for ArrowDtype (pandas-dev#52842)
1 parent 6d63392 commit fa94d3b

File tree

4 files changed

+216
-77
lines changed

4 files changed

+216
-77
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Other
5050
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
5151
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
5252
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
53+
- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
5354

5455
.. ---------------------------------------------------------------------------
5556
.. _whatsnew_201.contributors:

pandas/core/arrays/arrow/array.py

+85-52
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

33
from copy import deepcopy
4+
import functools
45
import operator
56
import re
7+
import sys
8+
import textwrap
69
from typing import (
710
TYPE_CHECKING,
811
Any,
@@ -12,6 +15,7 @@
1215
TypeVar,
1316
cast,
1417
)
18+
import unicodedata
1519

1620
import numpy as np
1721

@@ -1655,6 +1659,16 @@ def _replace_with_mask(
16551659
result[mask] = replacements
16561660
return pa.array(result, type=values.type, from_pandas=True)
16571661

1662+
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
1663+
"""Apply a callable to each element while maintaining the chunking structure."""
1664+
return [
1665+
[
1666+
None if val is None else func(val)
1667+
for val in chunk.to_numpy(zero_copy_only=False)
1668+
]
1669+
for chunk in self._data.iterchunks()
1670+
]
1671+
16581672
def _str_count(self, pat: str, flags: int = 0):
16591673
if flags:
16601674
raise NotImplementedError(f"count not implemented with {flags=}")
@@ -1788,14 +1802,14 @@ def _str_join(self, sep: str):
17881802
return type(self)(pc.binary_join(self._data, sep))
17891803

17901804
def _str_partition(self, sep: str, expand: bool):
1791-
raise NotImplementedError(
1792-
"str.partition not supported with pd.ArrowDtype(pa.string())."
1793-
)
1805+
predicate = lambda val: val.partition(sep)
1806+
result = self._apply_elementwise(predicate)
1807+
return type(self)(pa.chunked_array(result))
17941808

17951809
def _str_rpartition(self, sep: str, expand: bool):
1796-
raise NotImplementedError(
1797-
"str.rpartition not supported with pd.ArrowDtype(pa.string())."
1798-
)
1810+
predicate = lambda val: val.rpartition(sep)
1811+
result = self._apply_elementwise(predicate)
1812+
return type(self)(pa.chunked_array(result))
17991813

18001814
def _str_slice(
18011815
self, start: int | None = None, stop: int | None = None, step: int | None = None
@@ -1884,14 +1898,21 @@ def _str_rstrip(self, to_strip=None):
18841898
return type(self)(result)
18851899

18861900
def _str_removeprefix(self, prefix: str):
1887-
raise NotImplementedError(
1888-
"str.removeprefix not supported with pd.ArrowDtype(pa.string())."
1889-
)
18901901
# TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
18911902
# starts_with = pc.starts_with(self._data, pattern=prefix)
18921903
# removed = pc.utf8_slice_codeunits(self._data, len(prefix))
18931904
# result = pc.if_else(starts_with, removed, self._data)
18941905
# return type(self)(result)
1906+
if sys.version_info < (3, 9):
1907+
# NOTE pyupgrade will remove this when we run it with --py39-plus
1908+
# so don't remove the unnecessary `else` statement below
1909+
from pandas.util._str_methods import removeprefix
1910+
1911+
predicate = functools.partial(removeprefix, prefix=prefix)
1912+
else:
1913+
predicate = lambda val: val.removeprefix(prefix)
1914+
result = self._apply_elementwise(predicate)
1915+
return type(self)(pa.chunked_array(result))
18951916

18961917
def _str_removesuffix(self, suffix: str):
18971918
ends_with = pc.ends_with(self._data, pattern=suffix)
@@ -1900,49 +1921,59 @@ def _str_removesuffix(self, suffix: str):
19001921
return type(self)(result)
19011922

19021923
def _str_casefold(self):
1903-
raise NotImplementedError(
1904-
"str.casefold not supported with pd.ArrowDtype(pa.string())."
1905-
)
1924+
predicate = lambda val: val.casefold()
1925+
result = self._apply_elementwise(predicate)
1926+
return type(self)(pa.chunked_array(result))
19061927

1907-
def _str_encode(self, encoding, errors: str = "strict"):
1908-
raise NotImplementedError(
1909-
"str.encode not supported with pd.ArrowDtype(pa.string())."
1910-
)
1928+
def _str_encode(self, encoding: str, errors: str = "strict"):
1929+
predicate = lambda val: val.encode(encoding, errors)
1930+
result = self._apply_elementwise(predicate)
1931+
return type(self)(pa.chunked_array(result))
19111932

19121933
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
19131934
raise NotImplementedError(
19141935
"str.extract not supported with pd.ArrowDtype(pa.string())."
19151936
)
19161937

1917-
def _str_findall(self, pat, flags: int = 0):
1918-
raise NotImplementedError(
1919-
"str.findall not supported with pd.ArrowDtype(pa.string())."
1920-
)
1938+
def _str_findall(self, pat: str, flags: int = 0):
1939+
regex = re.compile(pat, flags=flags)
1940+
predicate = lambda val: regex.findall(val)
1941+
result = self._apply_elementwise(predicate)
1942+
return type(self)(pa.chunked_array(result))
19211943

19221944
def _str_get_dummies(self, sep: str = "|"):
1923-
raise NotImplementedError(
1924-
"str.get_dummies not supported with pd.ArrowDtype(pa.string())."
1925-
)
1926-
1927-
def _str_index(self, sub, start: int = 0, end=None):
1928-
raise NotImplementedError(
1929-
"str.index not supported with pd.ArrowDtype(pa.string())."
1930-
)
1931-
1932-
def _str_rindex(self, sub, start: int = 0, end=None):
1933-
raise NotImplementedError(
1934-
"str.rindex not supported with pd.ArrowDtype(pa.string())."
1935-
)
1936-
1937-
def _str_normalize(self, form):
1938-
raise NotImplementedError(
1939-
"str.normalize not supported with pd.ArrowDtype(pa.string())."
1940-
)
1941-
1942-
def _str_rfind(self, sub, start: int = 0, end=None):
1943-
raise NotImplementedError(
1944-
"str.rfind not supported with pd.ArrowDtype(pa.string())."
1945-
)
1945+
split = pc.split_pattern(self._data, sep).combine_chunks()
1946+
uniques = split.flatten().unique()
1947+
uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
1948+
result_data = []
1949+
for lst in split.to_pylist():
1950+
if lst is None:
1951+
result_data.append([False] * len(uniques_sorted))
1952+
else:
1953+
res = pc.is_in(uniques_sorted, pa.array(set(lst)))
1954+
result_data.append(res.to_pylist())
1955+
result = type(self)(pa.array(result_data))
1956+
return result, uniques_sorted.to_pylist()
1957+
1958+
def _str_index(self, sub: str, start: int = 0, end: int | None = None):
1959+
predicate = lambda val: val.index(sub, start, end)
1960+
result = self._apply_elementwise(predicate)
1961+
return type(self)(pa.chunked_array(result))
1962+
1963+
def _str_rindex(self, sub: str, start: int = 0, end: int | None = None):
1964+
predicate = lambda val: val.rindex(sub, start, end)
1965+
result = self._apply_elementwise(predicate)
1966+
return type(self)(pa.chunked_array(result))
1967+
1968+
def _str_normalize(self, form: str):
1969+
predicate = lambda val: unicodedata.normalize(form, val)
1970+
result = self._apply_elementwise(predicate)
1971+
return type(self)(pa.chunked_array(result))
1972+
1973+
def _str_rfind(self, sub: str, start: int = 0, end=None):
1974+
predicate = lambda val: val.rfind(sub, start, end)
1975+
result = self._apply_elementwise(predicate)
1976+
return type(self)(pa.chunked_array(result))
19461977

19471978
def _str_split(
19481979
self,
@@ -1964,15 +1995,17 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
19641995
n = None
19651996
return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True))
19661997

1967-
def _str_translate(self, table):
1968-
raise NotImplementedError(
1969-
"str.translate not supported with pd.ArrowDtype(pa.string())."
1970-
)
1971-
1972-
def _str_wrap(self, width, **kwargs):
1973-
raise NotImplementedError(
1974-
"str.wrap not supported with pd.ArrowDtype(pa.string())."
1975-
)
1998+
def _str_translate(self, table: dict[int, str]):
1999+
predicate = lambda val: val.translate(table)
2000+
result = self._apply_elementwise(predicate)
2001+
return type(self)(pa.chunked_array(result))
2002+
2003+
def _str_wrap(self, width: int, **kwargs):
2004+
kwargs["width"] = width
2005+
tw = textwrap.TextWrapper(**kwargs)
2006+
predicate = lambda val: "\n".join(tw.wrap(val))
2007+
result = self._apply_elementwise(predicate)
2008+
return type(self)(pa.chunked_array(result))
19762009

19772010
@property
19782011
def _dt_year(self):

pandas/core/strings/accessor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def _wrap_result(
267267
if expand is None:
268268
# infer from ndim if expand is not specified
269269
expand = result.ndim != 1
270-
271270
elif expand is True and not isinstance(self._orig, ABCIndex):
272271
# required when expand=True is explicitly specified
273272
# not needed when inferred
@@ -280,10 +279,15 @@ def _wrap_result(
280279
result._data.combine_chunks().value_lengths()
281280
).as_py()
282281
if result.isna().any():
282+
# ArrowExtensionArray.fillna doesn't work for list scalars
283283
result._data = result._data.fill_null([None] * max_len)
284+
if name is not None:
285+
labels = name
286+
else:
287+
labels = range(max_len)
284288
result = {
285-
i: ArrowExtensionArray(pa.array(res))
286-
for i, res in enumerate(zip(*result.tolist()))
289+
label: ArrowExtensionArray(pa.array(res))
290+
for label, res in zip(labels, (zip(*result.tolist())))
287291
}
288292
elif is_object_dtype(result):
289293

0 commit comments

Comments
 (0)