Skip to content

Backport PR #52614: ENH: Implement more str accessor methods for ArrowDtype #52842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Other
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)

.. ---------------------------------------------------------------------------
.. _whatsnew_201.contributors:
Expand Down
137 changes: 85 additions & 52 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from copy import deepcopy
import functools
import operator
import re
import sys
import textwrap
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -12,6 +15,7 @@
TypeVar,
cast,
)
import unicodedata

import numpy as np

Expand Down Expand Up @@ -1655,6 +1659,16 @@ def _replace_with_mask(
result[mask] = replacements
return pa.array(result, type=values.type, from_pandas=True)

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
"""Apply a callable to each element while maintaining the chunking structure."""
return [
[
None if val is None else func(val)
for val in chunk.to_numpy(zero_copy_only=False)
]
for chunk in self._data.iterchunks()
]

def _str_count(self, pat: str, flags: int = 0):
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
Expand Down Expand Up @@ -1788,14 +1802,14 @@ def _str_join(self, sep: str):
return type(self)(pc.binary_join(self._data, sep))

def _str_partition(self, sep: str, expand: bool):
raise NotImplementedError(
"str.partition not supported with pd.ArrowDtype(pa.string())."
)
predicate = lambda val: val.partition(sep)
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_rpartition(self, sep: str, expand: bool):
raise NotImplementedError(
"str.rpartition not supported with pd.ArrowDtype(pa.string())."
)
predicate = lambda val: val.rpartition(sep)
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_slice(
self, start: int | None = None, stop: int | None = None, step: int | None = None
Expand Down Expand Up @@ -1884,14 +1898,21 @@ def _str_rstrip(self, to_strip=None):
return type(self)(result)

def _str_removeprefix(self, prefix: str):
raise NotImplementedError(
"str.removeprefix not supported with pd.ArrowDtype(pa.string())."
)
# TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
# starts_with = pc.starts_with(self._data, pattern=prefix)
# removed = pc.utf8_slice_codeunits(self._data, len(prefix))
# result = pc.if_else(starts_with, removed, self._data)
# return type(self)(result)
if sys.version_info < (3, 9):
# NOTE pyupgrade will remove this when we run it with --py39-plus
# so don't remove the unnecessary `else` statement below
from pandas.util._str_methods import removeprefix

predicate = functools.partial(removeprefix, prefix=prefix)
else:
predicate = lambda val: val.removeprefix(prefix)
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._data, pattern=suffix)
Expand All @@ -1900,49 +1921,59 @@ def _str_removesuffix(self, suffix: str):
return type(self)(result)

def _str_casefold(self):
raise NotImplementedError(
"str.casefold not supported with pd.ArrowDtype(pa.string())."
)
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_encode(self, encoding, errors: str = "strict"):
raise NotImplementedError(
"str.encode not supported with pd.ArrowDtype(pa.string())."
)
def _str_encode(self, encoding: str, errors: str = "strict"):
    """
    Encode each non-null string with ``str.encode(encoding, errors)``.

    The per-chunk results from ``_apply_elementwise`` are wrapped in a
    pyarrow chunked array and returned as a new instance of ``type(self)``.
    """
    encoder = operator.methodcaller("encode", encoding, errors)
    encoded_chunks = self._apply_elementwise(encoder)
    return type(self)(pa.chunked_array(encoded_chunks))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
raise NotImplementedError(
"str.extract not supported with pd.ArrowDtype(pa.string())."
)

def _str_findall(self, pat, flags: int = 0):
raise NotImplementedError(
"str.findall not supported with pd.ArrowDtype(pa.string())."
)
def _str_findall(self, pat: str, flags: int = 0):
    """
    Collect all regex matches of ``pat`` in each non-null string.

    The pattern is compiled once with ``flags`` and its ``findall`` method
    is applied to every element; the per-chunk lists of matches are wrapped
    in a pyarrow chunked array and returned as a new ``type(self)``.
    """
    compiled = re.compile(pat, flags=flags)
    # The bound method is passed directly; no wrapper callable is needed.
    matches = self._apply_elementwise(compiled.findall)
    return type(self)(pa.chunked_array(matches))

def _str_get_dummies(self, sep: str = "|"):
raise NotImplementedError(
"str.get_dummies not supported with pd.ArrowDtype(pa.string())."
)

def _str_index(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.index not supported with pd.ArrowDtype(pa.string())."
)

def _str_rindex(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.rindex not supported with pd.ArrowDtype(pa.string())."
)

def _str_normalize(self, form):
raise NotImplementedError(
"str.normalize not supported with pd.ArrowDtype(pa.string())."
)

def _str_rfind(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.rfind not supported with pd.ArrowDtype(pa.string())."
)
split = pc.split_pattern(self._data, sep).combine_chunks()
uniques = split.flatten().unique()
uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
result_data = []
for lst in split.to_pylist():
if lst is None:
result_data.append([False] * len(uniques_sorted))
else:
res = pc.is_in(uniques_sorted, pa.array(set(lst)))
result_data.append(res.to_pylist())
result = type(self)(pa.array(result_data))
return result, uniques_sorted.to_pylist()

def _str_index(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the lowest index of ``sub`` in each non-null string.

    Delegates to ``str.index(sub, start, end)`` per element, so a
    ``ValueError`` propagates when ``sub`` is absent from an element.
    The results are returned as a new instance of ``type(self)``.
    """

    def locate(text):
        return text.index(sub, start, end)

    positions = self._apply_elementwise(locate)
    return type(self)(pa.chunked_array(positions))

def _str_rindex(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the highest index of ``sub`` in each non-null string.

    Delegates to ``str.rindex(sub, start, end)`` per element, so a
    ``ValueError`` propagates when ``sub`` is absent from an element.
    The results are returned as a new instance of ``type(self)``.
    """
    locate_last = operator.methodcaller("rindex", sub, start, end)
    positions = self._apply_elementwise(locate_last)
    return type(self)(pa.chunked_array(positions))

def _str_normalize(self, form: str):
    """
    Apply ``unicodedata.normalize(form, value)`` to each non-null string.

    The per-chunk results are wrapped in a pyarrow chunked array and
    returned as a new instance of ``type(self)``.
    """

    def to_form(text):
        return unicodedata.normalize(form, text)

    normalized = self._apply_elementwise(to_form)
    return type(self)(pa.chunked_array(normalized))

def _str_rfind(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the highest index of ``sub`` in each non-null string.

    Delegates to ``str.rfind(sub, start, end)`` per element, which yields
    ``-1`` (rather than raising) when ``sub`` is not found.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window, forwarded to ``str.rfind``.
    end : int or None, default None
        Right edge of the search window, forwarded to ``str.rfind``.

    Returns
    -------
    A new instance of ``type(self)`` built from the per-chunk results.
    """
    predicate = lambda val: val.rfind(sub, start, end)
    result = self._apply_elementwise(predicate)
    return type(self)(pa.chunked_array(result))

def _str_split(
self,
Expand All @@ -1964,15 +1995,17 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
n = None
return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True))

def _str_translate(self, table):
raise NotImplementedError(
"str.translate not supported with pd.ArrowDtype(pa.string())."
)

def _str_wrap(self, width, **kwargs):
raise NotImplementedError(
"str.wrap not supported with pd.ArrowDtype(pa.string())."
)
def _str_translate(self, table: dict[int, str]):
    """
    Map characters of each non-null string through ``str.translate(table)``.

    The per-chunk results are wrapped in a pyarrow chunked array and
    returned as a new instance of ``type(self)``.
    """
    translator = operator.methodcaller("translate", table)
    translated = self._apply_elementwise(translator)
    return type(self)(pa.chunked_array(translated))

def _str_wrap(self, width: int, **kwargs):
    """
    Wrap each non-null string to ``width`` columns, joining lines with newlines.

    ``width`` and any extra keyword arguments are forwarded to
    ``textwrap.TextWrapper``. The wrapped strings are returned as a new
    instance of ``type(self)``.
    """
    wrapper = textwrap.TextWrapper(width=width, **kwargs)
    # TextWrapper.fill(text) is "\n".join(TextWrapper.wrap(text)).
    wrapped = self._apply_elementwise(wrapper.fill)
    return type(self)(pa.chunked_array(wrapped))

@property
def _dt_year(self):
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def _wrap_result(
if expand is None:
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif expand is True and not isinstance(self._orig, ABCIndex):
# required when expand=True is explicitly specified
# not needed when inferred
Expand All @@ -280,10 +279,15 @@ def _wrap_result(
result._data.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
# ArrowExtensionArray.fillna doesn't work for list scalars
result._data = result._data.fill_null([None] * max_len)
if name is not None:
labels = name
else:
labels = range(max_len)
result = {
i: ArrowExtensionArray(pa.array(res))
for i, res in enumerate(zip(*result.tolist()))
label: ArrowExtensionArray(pa.array(res))
for label, res in zip(labels, (zip(*result.tolist())))
}
elif is_object_dtype(result):

Expand Down
Loading