Skip to content

Commit b4899fa

Browse files
mroeschkeYi Wei
authored and
Yi Wei
committed
ENH: Implement more str accessor methods for ArrowDtype (pandas-dev#52614)
* Add more str arrow functions * Finish functions * finish methods and add tests * Finish implementing * Fix >3.8 compat * Create helper function
1 parent 8662194 commit b4899fa

File tree

4 files changed

+214
-75
lines changed

4 files changed

+214
-75
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Other
5050
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
5151
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
5252
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
53+
- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
5354

5455
.. ---------------------------------------------------------------------------
5556
.. _whatsnew_201.contributors:

pandas/core/arrays/arrow/array.py

+83-50
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

3+
import functools
34
import operator
45
import re
6+
import sys
7+
import textwrap
58
from typing import (
69
TYPE_CHECKING,
710
Any,
@@ -10,6 +13,7 @@
1013
Sequence,
1114
cast,
1215
)
16+
import unicodedata
1317

1418
import numpy as np
1519

@@ -1749,6 +1753,16 @@ def _groupby_op(
17491753
return result
17501754
return type(self)._from_sequence(result, copy=False)
17511755

1756+
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
1757+
"""Apply a callable to each element while maintaining the chunking structure."""
1758+
return [
1759+
[
1760+
None if val is None else func(val)
1761+
for val in chunk.to_numpy(zero_copy_only=False)
1762+
]
1763+
for chunk in self._pa_array.iterchunks()
1764+
]
1765+
17521766
def _str_count(self, pat: str, flags: int = 0):
17531767
if flags:
17541768
raise NotImplementedError(f"count not implemented with {flags=}")
@@ -1882,14 +1896,14 @@ def _str_join(self, sep: str):
18821896
return type(self)(pc.binary_join(self._pa_array, sep))
18831897

18841898
def _str_partition(self, sep: str, expand: bool):
    """
    Partition each non-null string with ``str.partition(sep)``.

    ``expand`` is accepted for signature compatibility but is not used
    in this method. Nulls are preserved by ``_apply_elementwise``.
    """

    def split_first(text):
        return text.partition(sep)

    pieces = self._apply_elementwise(split_first)
    return type(self)(pa.chunked_array(pieces))
18881902

18891903
def _str_rpartition(self, sep: str, expand: bool):
    """
    Partition each non-null string with ``str.rpartition(sep)``.

    ``expand` is accepted for signature compatibility but is not used
    in this method. Nulls are preserved by ``_apply_elementwise``.
    """

    def split_last(text):
        return text.rpartition(sep)

    pieces = self._apply_elementwise(split_last)
    return type(self)(pa.chunked_array(pieces))
18931907

18941908
def _str_slice(
18951909
self, start: int | None = None, stop: int | None = None, step: int | None = None
@@ -1978,14 +1992,21 @@ def _str_rstrip(self, to_strip=None):
19781992
return type(self)(result)
19791993

19801994
def _str_removeprefix(self, prefix: str):
    """
    Remove ``prefix`` from the start of each non-null string.

    On Python older than 3.9 the ``removeprefix`` helper from
    ``pandas.util._str_methods`` is used; otherwise the built-in
    ``str.removeprefix`` is called.
    """
    # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
    # starts_with = pc.starts_with(self._pa_array, pattern=prefix)
    # removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
    # result = pc.if_else(starts_with, removed, self._pa_array)
    # return type(self)(result)
    if sys.version_info < (3, 9):
        # NOTE pyupgrade will remove this when we run it with --py39-plus
        # so don't remove the unnecessary `else` statement below
        from pandas.util._str_methods import removeprefix

        strip_prefix = functools.partial(removeprefix, prefix=prefix)
    else:
        strip_prefix = lambda text: text.removeprefix(prefix)
    stripped = self._apply_elementwise(strip_prefix)
    return type(self)(pa.chunked_array(stripped))
19892010

19902011
def _str_removesuffix(self, suffix: str):
19912012
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
@@ -1994,49 +2015,59 @@ def _str_removesuffix(self, suffix: str):
19942015
return type(self)(result)
19952016

19962017
def _str_casefold(self):
    """Call ``casefold()`` on each non-null element and wrap the result."""
    folded = self._apply_elementwise(operator.methodcaller("casefold"))
    return type(self)(pa.chunked_array(folded))
20002021

2001-
def _str_encode(self, encoding: str, errors: str = "strict"):
    """
    Call ``encode(encoding, errors)`` on each non-null element.

    Parameters
    ----------
    encoding : str
        Forwarded as the first argument of ``encode``.
    errors : str, default "strict"
        Forwarded as the second argument of ``encode``.
    """
    to_bytes = operator.methodcaller("encode", encoding, errors)
    encoded = self._apply_elementwise(to_bytes)
    return type(self)(pa.chunked_array(encoded))
20052026

20062027
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
20072028
raise NotImplementedError(
20082029
"str.extract not supported with pd.ArrowDtype(pa.string())."
20092030
)
20102031

2011-
def _str_findall(self, pat: str, flags: int = 0):
    """
    Collect ``re`` ``findall`` matches of ``pat`` for each non-null element.

    The pattern is compiled once with ``flags`` and the compiled
    pattern's ``findall`` is applied to every element.
    """
    compiled = re.compile(pat, flags=flags)
    matches = self._apply_elementwise(compiled.findall)
    return type(self)(pa.chunked_array(matches))
20152037

20162038
def _str_get_dummies(self, sep: str = "|"):
    """
    Build per-row membership flags for the tokens obtained by splitting on ``sep``.

    Returns
    -------
    tuple
        ``(flags, labels)``. ``labels`` is the sorted list of distinct
        tokens. ``flags`` wraps one list of booleans per row, aligned with
        ``labels``; a row whose split result is ``None`` gets all ``False``.
    """
    tokens = pc.split_pattern(self._pa_array, sep).combine_chunks()
    distinct = tokens.flatten().unique()
    labels = distinct.take(pa.compute.array_sort_indices(distinct))
    n_labels = len(labels)
    flag_rows = [
        [False] * n_labels
        if row is None
        else pc.is_in(labels, pa.array(set(row))).to_pylist()
        for row in tokens.to_pylist()
    ]
    flags = type(self)(pa.array(flag_rows))
    return flags, labels.to_pylist()
2051+
2052+
def _str_index(self, sub: str, start: int = 0, end: int | None = None):
    """
    Call ``index(sub, start, end)`` on each non-null element.

    Exceptions raised by ``index`` for an element are not caught here.
    """

    def locate(text):
        return text.index(sub, start, end)

    positions = self._apply_elementwise(locate)
    return type(self)(pa.chunked_array(positions))
2056+
2057+
def _str_rindex(self, sub: str, start: int = 0, end: int | None = None):
    """
    Call ``rindex(sub, start, end)`` on each non-null element.

    Exceptions raised by ``rindex`` for an element are not caught here.
    """

    def locate_last(text):
        return text.rindex(sub, start, end)

    positions = self._apply_elementwise(locate_last)
    return type(self)(pa.chunked_array(positions))
2061+
2062+
def _str_normalize(self, form: str):
    """Apply ``unicodedata.normalize(form, value)`` to each non-null element."""
    normalizer = functools.partial(unicodedata.normalize, form)
    normalized = self._apply_elementwise(normalizer)
    return type(self)(pa.chunked_array(normalized))
2066+
2067+
def _str_rfind(self, sub: str, start: int = 0, end: int | None = None):
    """
    Call ``rfind(sub, start, end)`` on each non-null element.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Forwarded to ``rfind`` as the start position.
    end : int or None, default None
        Forwarded to ``rfind`` as the end position; annotated to match
        the sibling ``_str_index`` / ``_str_rindex`` signatures.
    """
    predicate = lambda val: val.rfind(sub, start, end)
    result = self._apply_elementwise(predicate)
    return type(self)(pa.chunked_array(result))
20402071

20412072
def _str_split(
20422073
self,
@@ -2060,15 +2091,17 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
20602091
pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True)
20612092
)
20622093

2063-
def _str_translate(self, table: dict[int, str]):
    """Call ``translate(table)`` on each non-null element and wrap the result."""
    translated = self._apply_elementwise(operator.methodcaller("translate", table))
    return type(self)(pa.chunked_array(translated))
20672098

20682099
def _str_wrap(self, width: int, **kwargs):
    """
    Wrap each non-null string with a ``textwrap.TextWrapper``.

    Parameters
    ----------
    width : int
        Passed to ``TextWrapper`` as ``width``.
    **kwargs
        Additional keyword arguments forwarded to ``TextWrapper``.

    Notes
    -----
    The wrapped lines of each element are joined with ``"\\n"``.
    """
    wrapper = textwrap.TextWrapper(width=width, **kwargs)

    def rewrap(text):
        return "\n".join(wrapper.wrap(text))

    wrapped = self._apply_elementwise(rewrap)
    return type(self)(pa.chunked_array(wrapped))
20722105

20732106
@property
20742107
def _dt_year(self):

pandas/core/strings/accessor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def _wrap_result(
267267
if expand is None:
268268
# infer from ndim if expand is not specified
269269
expand = result.ndim != 1
270-
271270
elif expand is True and not isinstance(self._orig, ABCIndex):
272271
# required when expand=True is explicitly specified
273272
# not needed when inferred
@@ -280,10 +279,15 @@ def _wrap_result(
280279
result._pa_array.combine_chunks().value_lengths()
281280
).as_py()
282281
if result.isna().any():
282+
# ArrowExtensionArray.fillna doesn't work for list scalars
283283
result._pa_array = result._pa_array.fill_null([None] * max_len)
284+
if name is not None:
285+
labels = name
286+
else:
287+
labels = range(max_len)
284288
result = {
285-
i: ArrowExtensionArray(pa.array(res))
286-
for i, res in enumerate(zip(*result.tolist()))
289+
label: ArrowExtensionArray(pa.array(res))
290+
for label, res in zip(labels, (zip(*result.tolist())))
287291
}
288292
elif is_object_dtype(result):
289293

0 commit comments

Comments
 (0)