From aafa1cffbb7d7f0397ddc0f7a7dbc7669e4134b6 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 21 Apr 2023 16:41:06 -0700 Subject: [PATCH] Backport PR #52614: ENH: Implement more str accessor methods for ArrowDtype --- doc/source/whatsnew/v2.0.1.rst | 1 + pandas/core/arrays/arrow/array.py | 137 +++++++++++++++---------- pandas/core/strings/accessor.py | 10 +- pandas/tests/extension/test_arrow.py | 145 +++++++++++++++++++++++---- 4 files changed, 216 insertions(+), 77 deletions(-) diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index 92731426bed03..5c9276d0fbf8a 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -50,6 +50,7 @@ Other - :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) - :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`) - Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) +- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`) .. --------------------------------------------------------------------------- .. _whatsnew_201.contributors: diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a99c5df5ca025..5303a2447a5bf 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,8 +1,11 @@ from __future__ import annotations from copy import deepcopy +import functools import operator import re +import sys +import textwrap from typing import ( TYPE_CHECKING, Any, @@ -12,6 +15,7 @@ TypeVar, cast, ) +import unicodedata import numpy as np @@ -1655,6 +1659,16 @@ def _replace_with_mask( result[mask] = replacements return pa.array(result, type=values.type, from_pandas=True) + def _apply_elementwise(self, func: Callable) -> list[list[Any]]: + """Apply a callable to each element while maintaining the chunking structure.""" + return [ + [ + None if val is None else func(val) + for val in chunk.to_numpy(zero_copy_only=False) + ] + for chunk in self._data.iterchunks() + ] + def _str_count(self, pat: str, flags: int = 0): if flags: raise NotImplementedError(f"count not implemented with {flags=}") @@ -1788,14 +1802,14 @@ def _str_join(self, sep: str): return type(self)(pc.binary_join(self._data, sep)) def _str_partition(self, sep: str, expand: bool): - raise NotImplementedError( - "str.partition not supported with pd.ArrowDtype(pa.string())." - ) + predicate = lambda val: val.partition(sep) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_rpartition(self, sep: str, expand: bool): - raise NotImplementedError( - "str.rpartition not supported with pd.ArrowDtype(pa.string())." - ) + predicate = lambda val: val.rpartition(sep) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None @@ -1884,14 +1898,21 @@ def _str_rstrip(self, to_strip=None): return type(self)(result) def _str_removeprefix(self, prefix: str): - raise NotImplementedError( - "str.removeprefix not supported with pd.ArrowDtype(pa.string())." - ) # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed # starts_with = pc.starts_with(self._data, pattern=prefix) # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) # result = pc.if_else(starts_with, removed, self._data) # return type(self)(result) + if sys.version_info < (3, 9): + # NOTE pyupgrade will remove this when we run it with --py39-plus + # so don't remove the unnecessary `else` statement below + from pandas.util._str_methods import removeprefix + + predicate = functools.partial(removeprefix, prefix=prefix) + else: + predicate = lambda val: val.removeprefix(prefix) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_removesuffix(self, suffix: str): ends_with = pc.ends_with(self._data, pattern=suffix) @@ -1900,49 +1921,59 @@ def _str_removesuffix(self, suffix: str): return type(self)(result) def _str_casefold(self): - raise NotImplementedError( - "str.casefold not supported with pd.ArrowDtype(pa.string())." - ) + predicate = lambda val: val.casefold() + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) - def _str_encode(self, encoding, errors: str = "strict"): - raise NotImplementedError( - "str.encode not supported with pd.ArrowDtype(pa.string())." - ) + def _str_encode(self, encoding: str, errors: str = "strict"): + predicate = lambda val: val.encode(encoding, errors) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): raise NotImplementedError( "str.extract not supported with pd.ArrowDtype(pa.string())." ) - def _str_findall(self, pat, flags: int = 0): - raise NotImplementedError( - "str.findall not supported with pd.ArrowDtype(pa.string())." - ) + def _str_findall(self, pat: str, flags: int = 0): + regex = re.compile(pat, flags=flags) + predicate = lambda val: regex.findall(val) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_get_dummies(self, sep: str = "|"): - raise NotImplementedError( - "str.get_dummies not supported with pd.ArrowDtype(pa.string())." - ) - - def _str_index(self, sub, start: int = 0, end=None): - raise NotImplementedError( - "str.index not supported with pd.ArrowDtype(pa.string())." - ) - - def _str_rindex(self, sub, start: int = 0, end=None): - raise NotImplementedError( - "str.rindex not supported with pd.ArrowDtype(pa.string())." - ) - - def _str_normalize(self, form): - raise NotImplementedError( - "str.normalize not supported with pd.ArrowDtype(pa.string())." - ) - - def _str_rfind(self, sub, start: int = 0, end=None): - raise NotImplementedError( - "str.rfind not supported with pd.ArrowDtype(pa.string())." - ) + split = pc.split_pattern(self._data, sep).combine_chunks() + uniques = split.flatten().unique() + uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques)) + result_data = [] + for lst in split.to_pylist(): + if lst is None: + result_data.append([False] * len(uniques_sorted)) + else: + res = pc.is_in(uniques_sorted, pa.array(set(lst))) + result_data.append(res.to_pylist()) + result = type(self)(pa.array(result_data)) + return result, uniques_sorted.to_pylist() + + def _str_index(self, sub: str, start: int = 0, end: int | None = None): + predicate = lambda val: val.index(sub, start, end) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) + + def _str_rindex(self, sub: str, start: int = 0, end: int | None = None): + predicate = lambda val: val.rindex(sub, start, end) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) + + def _str_normalize(self, form: str): + predicate = lambda val: unicodedata.normalize(form, val) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) + + def _str_rfind(self, sub: str, start: int = 0, end=None): + predicate = lambda val: val.rfind(sub, start, end) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) def _str_split( self, @@ -1964,15 +1995,17 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1): n = None return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True)) - def _str_translate(self, table): - raise NotImplementedError( - "str.translate not supported with pd.ArrowDtype(pa.string())." - ) - - def _str_wrap(self, width, **kwargs): - raise NotImplementedError( - "str.wrap not supported with pd.ArrowDtype(pa.string())." - ) + def _str_translate(self, table: dict[int, str]): + predicate = lambda val: val.translate(table) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) + + def _str_wrap(self, width: int, **kwargs): + kwargs["width"] = width + tw = textwrap.TextWrapper(**kwargs) + predicate = lambda val: "\n".join(tw.wrap(val)) + result = self._apply_elementwise(predicate) + return type(self)(pa.chunked_array(result)) @property def _dt_year(self): diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index e44ccd9aede83..699a32fe0c028 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -267,7 +267,6 @@ def _wrap_result( if expand is None: # infer from ndim if expand is not specified expand = result.ndim != 1 - elif expand is True and not isinstance(self._orig, ABCIndex): # required when expand=True is explicitly specified # not needed when inferred @@ -280,10 +279,15 @@ def _wrap_result( result._data.combine_chunks().value_lengths() ).as_py() if result.isna().any(): + # ArrowExtensionArray.fillna doesn't work for list scalars result._data = result._data.fill_null([None] * max_len) + if name is not None: + labels = name + else: + labels = range(max_len) result = { - i: ArrowExtensionArray(pa.array(res)) - for i, res in enumerate(zip(*result.tolist())) + label: ArrowExtensionArray(pa.array(res)) + for label, res in zip(labels, (zip(*result.tolist()))) } elif is_object_dtype(result): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index f3654bbb8c334..095b892b3bc78 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2094,6 +2094,7 @@ def test_str_is_functions(value, method, exp): ["swapcase", "AbC Def"], ["lower", "abc def"], ["upper", "ABC DEF"], + ["casefold", "abc def"], ], ) def test_str_transform_functions(method, exp): @@ -2136,6 +2137,125 @@ def test_str_removesuffix(val): tm.assert_series_equal(result, expected) +@pytest.mark.parametrize("val", ["123abc", "abc"]) +def test_str_removeprefix(val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = ser.str.removeprefix("123") + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("errors", ["ignore", "strict"]) +@pytest.mark.parametrize( + "encoding, exp", + [ + ["utf8", b"abc"], + ["utf32", b"\xff\xfe\x00\x00a\x00\x00\x00b\x00\x00\x00c\x00\x00\x00"], + ], +) +def test_str_encode(errors, encoding, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.encode(encoding, errors) + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.binary())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("flags", [0, 1]) +def test_str_findall(flags): + ser = pd.Series(["abc", "efg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.findall("b", flags=flags) + expected = pd.Series([["b"], [], None], dtype=ArrowDtype(pa.list_(pa.string()))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["index", "rindex"]) +@pytest.mark.parametrize( + "start, end", + [ + [0, None], + [1, 4], + ], +) +def test_str_r_index(method, start, end): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)("c", start, end) + expected = pd.Series([2, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + with pytest.raises(ValueError, match="substring not found"): + getattr(ser.str, method)("foo", start, end) + + +@pytest.mark.parametrize("form", ["NFC", "NFKC"]) +def test_str_normalize(form): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.normalize(form) + expected = ser.copy() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, end", + [ + [0, None], + [1, 4], + ], +) +def test_str_rfind(start, end): + ser = pd.Series(["abcba", "foo", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rfind("c", start, end) + expected = pd.Series([2, -1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_translate(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.translate({97: "b"}) + expected = pd.Series(["bbcbb", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_wrap(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.wrap(3) + expected = pd.Series(["abc\nba", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_get_dummies(): + ser = pd.Series(["a|b", None, "a|c"], dtype=ArrowDtype(pa.string())) + result = ser.str.get_dummies() + expected = pd.DataFrame( + [[True, True, False], [False, False, False], [True, False, True]], + dtype=ArrowDtype(pa.bool_()), + columns=["a", "b", "c"], + ) + tm.assert_frame_equal(result, expected) + + +def test_str_partition(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.partition("b") + expected = pd.DataFrame( + [["a", "b", "cba"], [None, None, None]], dtype=ArrowDtype(pa.string()) + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.partition("b", expand=False) + expected = pd.Series(ArrowExtensionArray(pa.array([["a", "b", "cba"], None]))) + tm.assert_series_equal(result, expected) + + result = ser.str.rpartition("b") + expected = pd.DataFrame( + [["abc", "b", "a"], [None, None, None]], dtype=ArrowDtype(pa.string()) + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.rpartition("b", expand=False) + expected = pd.Series(ArrowExtensionArray(pa.array([["abc", "b", "a"], None]))) + tm.assert_series_equal(result, expected) + + def test_str_split(): # GH 52401 ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) @@ -2192,31 +2312,12 @@ def test_str_rsplit(): tm.assert_frame_equal(result, expected) -@pytest.mark.parametrize( - "method, args", - [ - ["partition", ("abc", False)], - ["rpartition", ("abc", False)], - ["removeprefix", ("abc",)], - ["casefold", ()], - ["encode", ("abc",)], - ["extract", (r"[ab](\d)",)], - ["findall", ("abc",)], - ["get_dummies", ()], - ["index", ("abc",)], - ["rindex", ("abc",)], - ["normalize", ("abc",)], - ["rfind", ("abc",)], - ["translate", ("abc",)], - ["wrap", ("abc",)], - ], -) -def test_str_unsupported_methods(method, args): +def test_str_unsupported_extract(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) with pytest.raises( - NotImplementedError, match=f"str.{method} not supported with pd.ArrowDtype" + NotImplementedError, match="str.extract not supported with pd.ArrowDtype" ): - getattr(ser.str, method)(*args) + ser.str.extract(r"[ab](\d)") @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])