diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index 7fc856be374e9..85993304b4407 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -284,6 +284,7 @@ Alternatively, copy on write can be enabled locally through:
 
 Other enhancements
 ^^^^^^^^^^^^^^^^^^
+- Added support for ``str`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.string`` type (:issue:`50325`)
 - Added support for ``dt`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.timestamp`` type (:issue:`50954`)
 - :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
 - :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 16cfb6d7c396b..df8d6172dd6f6 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from copy import deepcopy
+import re
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Literal,
+    Sequence,
     TypeVar,
     cast,
 )
@@ -55,6 +58,7 @@
     unpack_tuple_and_ellipses,
     validate_indices,
 )
+from pandas.core.strings.base import BaseStringArrayMethods
 
 from pandas.tseries.frequencies import to_offset
 
@@ -165,7 +169,7 @@ def to_pyarrow_type(
     return None
 
 
-class ArrowExtensionArray(OpsMixin, ExtensionArray):
+class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
     """
     Pandas ExtensionArray backed by a PyArrow ChunkedArray.
 
@@ -1463,6 +1467,331 @@ def _replace_with_mask(
             result[mask] = replacements
         return pa.array(result, type=values.type, from_pandas=True)
 
+    def _str_count(self, pat: str, flags: int = 0):
+        if flags:
+            raise NotImplementedError(f"count not implemented with {flags=}")
+        return type(self)(pc.count_substring_regex(self._data, pat))
+
+    def _str_pad(
+        self,
+        width: int,
+        side: Literal["left", "right", "both"] = "left",
+        fillchar: str = " ",
+    ):
+        if side == "left":
+            pa_pad = pc.utf8_lpad
+        elif side == "right":
+            pa_pad = pc.utf8_rpad
+        elif side == "both":
+            pa_pad = pc.utf8_center
+        else:
+            raise ValueError(
+                f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
+            )
+        return type(self)(pa_pad(self._data, width=width, padding=fillchar))
+
+    def _str_contains(
+        self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
+    ):
+        if flags:
+            raise NotImplementedError(f"contains not implemented with {flags=}")
+
+        if regex:
+            pa_contains = pc.match_substring_regex
+        else:
+            pa_contains = pc.match_substring
+        result = pa_contains(self._data, pat, ignore_case=not case)
+        if not isna(na):
+            result = result.fill_null(na)
+        return type(self)(result)
+
+    def _str_startswith(self, pat: str, na=None):
+        result = pc.starts_with(self._data, pattern=pat)
+        if not isna(na):
+            result = result.fill_null(na)
+        return type(self)(result)
+
+    def _str_endswith(self, pat: str, na=None):
+        result = pc.ends_with(self._data, pattern=pat)
+        if not isna(na):
+            result = result.fill_null(na)
+        return type(self)(result)
+
+    def _str_replace(
+        self,
+        pat: str | re.Pattern,
+        repl: str | Callable,
+        n: int = -1,
+        case: bool = True,
+        flags: int = 0,
+        regex: bool = True,
+    ):
+        if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
+            raise NotImplementedError(
+                "replace is not supported with a re.Pattern, callable repl, "
+                "case=False, or flags!=0"
+            )
+
+        func = pc.replace_substring_regex if regex else pc.replace_substring
+        result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
+        return type(self)(result)
+
+    def _str_repeat(self, repeats: int | Sequence[int]):
+        if not isinstance(repeats, int):
+            raise NotImplementedError(
+                f"repeat is not implemented when repeats is {type(repeats).__name__}"
+            )
+        elif pa_version_under7p0:
+            raise NotImplementedError("repeat is not implemented for pyarrow < 7")
+        else:
+            return type(self)(pc.binary_repeat(self._data, repeats))
+
+    def _str_match(
+        self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
+    ):
+        if not pat.startswith("^"):
+            pat = f"^{pat}"
+        return self._str_contains(pat, case, flags, na, regex=True)
+
+    def _str_fullmatch(
+        self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
+    ):
+        # Append the end anchor unless the pattern already ends with an unescaped
+        # "$"; a trailing "\$" is a literal dollar sign, not an anchor.
+        if not pat.endswith("$") or pat.endswith("\\$"):
+            pat = f"{pat}$"
+        return self._str_match(pat, case, flags, na)
+
+    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
+        if start != 0 and end is not None:
+            slices = pc.utf8_slice_codeunits(self._data, start, stop=end)
+            result = pc.find_substring(slices, sub)
+            not_found = pc.equal(result, -1)
+            # Hits are relative to the slice, so shift them by where the slice
+            # begins in the original string (not by the slice length).
+            if start < 0:
+                # A negative start counts from the end: max(len + start, 0) per row.
+                start_offset = pc.add(start, pc.utf8_length(self._data))
+                start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
+            else:
+                start_offset = start
+            offset_result = pc.add(result, start_offset)
+            result = pc.if_else(not_found, result, offset_result)
+        elif start == 0 and end is None:
+            slices = self._data
+            result = pc.find_substring(slices, sub)
+        else:
+            raise NotImplementedError(
+                f"find not implemented with {sub=}, {start=}, {end=}"
+            )
+        return type(self)(result)
+
+    def _str_get(self, i: int):
+        lengths = pc.utf8_length(self._data)
+        if i >= 0:
+            out_of_bounds = pc.greater_equal(i, lengths)
+            start = i
+            stop = i + 1
+            step = 1
+        else:
+            out_of_bounds = pc.greater(-i, lengths)
+            start = i
+            stop = i - 1
+            step = -1
+        not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
+        selected = pc.utf8_slice_codeunits(
+            self._data, start=start, stop=stop, step=step
+        )
+        result = pa.array([None] * self._data.length(), type=self._data.type)
+        result = pc.if_else(not_out_of_bounds, selected, result)
+        return type(self)(result)
+
+    def _str_join(self, sep: str):
+        return type(self)(pc.binary_join(self._data, sep))
+
+    def _str_partition(self, sep: str, expand: bool):
+        raise NotImplementedError(
+            "str.partition not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_rpartition(self, sep: str, expand: bool):
+        raise NotImplementedError(
+            "str.rpartition not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_slice(
+        self, start: int | None = None, stop: int | None = None, step: int | None = None
+    ):
+        if start is None:
+            start = 0
+        if step is None:
+            step = 1
+        return type(self)(
+            pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step)
+        )
+
+    def _str_slice_replace(
+        self, start: int | None = None, stop: int | None = None, repl: str | None = None
+    ):
+        if repl is None:
+            repl = ""
+        if start is None:
+            start = 0
+        return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl))
+
+    def _str_isalnum(self):
+        return type(self)(pc.utf8_is_alnum(self._data))
+
+    def _str_isalpha(self):
+        return type(self)(pc.utf8_is_alpha(self._data))
+
+    def _str_isdecimal(self):
+        return type(self)(pc.utf8_is_decimal(self._data))
+
+    def _str_isdigit(self):
+        return type(self)(pc.utf8_is_digit(self._data))
+
+    def _str_islower(self):
+        return type(self)(pc.utf8_is_lower(self._data))
+
+    def _str_isnumeric(self):
+        return type(self)(pc.utf8_is_numeric(self._data))
+
+    def _str_isspace(self):
+        return type(self)(pc.utf8_is_space(self._data))
+
+    def _str_istitle(self):
+        return type(self)(pc.utf8_is_title(self._data))
+
+    def _str_capitalize(self):
+        return type(self)(pc.utf8_capitalize(self._data))
+
+    def _str_title(self):
+        return type(self)(pc.utf8_title(self._data))
+
+    def _str_isupper(self):
+        return type(self)(pc.utf8_is_upper(self._data))
+
+    def _str_swapcase(self):
+        return type(self)(pc.utf8_swapcase(self._data))
+
+    def _str_len(self):
+        return type(self)(pc.utf8_length(self._data))
+
+    def _str_lower(self):
+        return type(self)(pc.utf8_lower(self._data))
+
+    def _str_upper(self):
+        return type(self)(pc.utf8_upper(self._data))
+
+    def _str_strip(self, to_strip=None):
+        if to_strip is None:
+            result = pc.utf8_trim_whitespace(self._data)
+        else:
+            result = pc.utf8_trim(self._data, characters=to_strip)
+        return type(self)(result)
+
+    def _str_lstrip(self, to_strip=None):
+        if to_strip is None:
+            result = pc.utf8_ltrim_whitespace(self._data)
+        else:
+            result = pc.utf8_ltrim(self._data, characters=to_strip)
+        return type(self)(result)
+
+    def _str_rstrip(self, to_strip=None):
+        if to_strip is None:
+            result = pc.utf8_rtrim_whitespace(self._data)
+        else:
+            result = pc.utf8_rtrim(self._data, characters=to_strip)
+        return type(self)(result)
+
+    def _str_removeprefix(self, prefix: str):
+        raise NotImplementedError(
+            "str.removeprefix not supported with pd.ArrowDtype(pa.string())."
+        )
+        # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
+        # starts_with = pc.starts_with(self._data, pattern=prefix)
+        # removed = pc.utf8_slice_codeunits(self._data, len(prefix))
+        # result = pc.if_else(starts_with, removed, self._data)
+        # return type(self)(result)
+
+    def _str_removesuffix(self, suffix: str):
+        if not suffix:
+            # stop=-len("") is 0, which would slice every string down to "";
+            # an empty suffix must leave the data unchanged.
+            return type(self)(self._data)
+        ends_with = pc.ends_with(self._data, pattern=suffix)
+        removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix))
+        result = pc.if_else(ends_with, removed, self._data)
+        return type(self)(result)
+
+    def _str_casefold(self):
+        raise NotImplementedError(
+            "str.casefold not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_encode(self, encoding, errors: str = "strict"):
+        raise NotImplementedError(
+            "str.encode not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
+        raise NotImplementedError(
+            "str.extract not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_findall(self, pat, flags: int = 0):
+        raise NotImplementedError(
+            "str.findall not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_get_dummies(self, sep: str = "|"):
+        raise NotImplementedError(
+            "str.get_dummies not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_index(self, sub, start: int = 0, end=None):
+        raise NotImplementedError(
+            "str.index not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_rindex(self, sub, start: int = 0, end=None):
+        raise NotImplementedError(
+            "str.rindex not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_normalize(self, form):
+        raise NotImplementedError(
+            "str.normalize not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_rfind(self, sub, start: int = 0, end=None):
+        raise NotImplementedError(
+            "str.rfind not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_split(
+        self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
+    ):
+        raise NotImplementedError(
+            "str.split not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_rsplit(self, pat=None, n=-1):
+        raise NotImplementedError(
+            "str.rsplit not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_translate(self, table):
+        raise NotImplementedError(
+            "str.translate not supported with pd.ArrowDtype(pa.string())."
+        )
+
+    def _str_wrap(self, width, **kwargs):
+        raise NotImplementedError(
+            "str.wrap not supported with pd.ArrowDtype(pa.string())."
+        )
+
     @property
     def _dt_day(self):
         return type(self)(pc.day(self._data))
diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index 9ad92471b98b4..4d2b39ec61fca 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -60,7 +60,7 @@ def _chk_pyarrow_available() -> None:
 # fallback for the ones that pyarrow doesn't yet support
 
 
-class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
+class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
     """
     Extension array for string data in a ``pyarrow.ChunkedArray``.
 
@@ -117,6 +117,16 @@ def __init__(self, values) -> None:
                 "ArrowStringArray requires a PyArrow (chunked) array of string type"
             )
 
+    def __len__(self) -> int:
+        """
+        Length of this array.
+
+        Returns
+        -------
+        length : int
+        """
+        return len(self._data)
+
     @classmethod
     def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
         from pandas.core.arrays.masked import BaseMaskedArray
diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py
index 70572897c1459..dfbb493636998 100644
--- a/pandas/core/indexes/base.py
+++ b/pandas/core/indexes/base.py
@@ -169,7 +169,7 @@
     get_group_index_sorter,
     nargsort,
 )
-from pandas.core.strings import StringMethods
+from pandas.core.strings.accessor import StringMethods
 
 from pandas.io.formats.printing import (
     PrettyDict,
diff --git a/pandas/core/series.py b/pandas/core/series.py
index e4c7c4d3b3d73..d69c057c85783 100644
--- a/pandas/core/series.py
+++ b/pandas/core/series.py
@@ -163,7 +163,7 @@
     ensure_key_mapped,
     nargsort,
 )
-from pandas.core.strings import StringMethods
+from pandas.core.strings.accessor import StringMethods
 from pandas.core.tools.datetimes import to_datetime
 
 import pandas.io.formats.format as fmt
diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py
index 28aba7c9ce0b3..eb650477c2b6b 100644
--- a/pandas/core/strings/__init__.py
+++ b/pandas/core/strings/__init__.py
@@ -26,8 +26,3 @@
 # - PandasArray
 # - Categorical
 # - ArrowStringArray
-
-from pandas.core.strings.accessor import StringMethods
-from pandas.core.strings.base import BaseStringArrayMethods
-
-__all__ = ["StringMethods", "BaseStringArrayMethods"]
diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py
index c96e5a1abcf86..f1e716b64644a 100644
--- a/pandas/core/strings/base.py
+++ b/pandas/core/strings/base.py
@@ -246,7 +246,9 @@ def _str_removesuffix(self, suffix: str) -> Series:
         pass
 
     @abc.abstractmethod
-    def _str_split(self, pat=None, n=-1, expand: bool = False):
+    def _str_split(
+        self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
+    ):
         pass
 
     @abc.abstractmethod
diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py
index 9406c7c2f59c6..50fb636c2beb8 100644
--- a/pandas/tests/extension/test_arrow.py
+++ b/pandas/tests/extension/test_arrow.py
@@ -21,6 +21,7 @@
     StringIO,
 )
 import pickle
+import re
 
 import numpy as np
 import pytest
@@ -1594,6 +1595,359 @@ def test_searchsorted_with_na_raises(data_for_sorting, as_series):
         arr.searchsorted(b)
 
 
+@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"])
+def test_str_count(pat):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.count(pat)
+    expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_count_flags_unsupported():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(NotImplementedError, match="count not"):
+        ser.str.count("abc", flags=1)
+
+
+@pytest.mark.parametrize(
+    "side, str_func", [["left", "rjust"], ["right", "ljust"], ["both", "center"]]
+)
+def test_str_pad(side, str_func):
+    ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.pad(width=3, side=side, fillchar="x")
+    expected = pd.Series(
+        [getattr("a", str_func)(3, "x"), None], dtype=ArrowDtype(pa.string())
+    )
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_pad_invalid_side():
+    ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(ValueError, match="Invalid side: foo"):
+        ser.str.pad(3, "foo", "x")
+
+
+@pytest.mark.parametrize(
+    "pat, case, na, regex, exp",
+    [
+        ["ab", False, None, False, [True, None]],
+        ["Ab", True, None, False, [False, None]],
+        ["ab", False, True, False, [True, True]],
+        ["a[a-z]{1}", False, None, True, [True, None]],
+        ["A[a-z]{1}", True, None, True, [False, None]],
+    ],
+)
+def test_str_contains(pat, case, na, regex, exp):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.contains(pat, case=case, na=na, regex=regex)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_contains_flags_unsupported():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(NotImplementedError, match="contains not"):
+        ser.str.contains("a", flags=1)
+
+
+@pytest.mark.parametrize(
+    "side, pat, na, exp",
+    [
+        ["startswith", "ab", None, [True, None]],
+        ["startswith", "b", False, [False, False]],
+        ["endswith", "b", True, [False, True]],
+        ["endswith", "bc", None, [True, None]],
+    ],
+)
+def test_str_start_ends_with(side, pat, na, exp):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = getattr(ser.str, side)(pat, na=na)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "arg_name, arg",
+    [["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]],
+)
+def test_str_replace_unsupported(arg_name, arg):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    kwargs = {"pat": "b", "repl": "x", "regex": True}
+    kwargs[arg_name] = arg
+    with pytest.raises(NotImplementedError, match="replace is not supported"):
+        ser.str.replace(**kwargs)
+
+
+@pytest.mark.parametrize(
+    "pat, repl, n, regex, exp",
+    [
+        ["a", "x", -1, False, ["xbxc", None]],
+        ["a", "x", 1, False, ["xbac", None]],
+        ["[a-b]", "x", -1, True, ["xxxc", None]],
+    ],
+)
+def test_str_replace(pat, repl, n, regex, exp):
+    ser = pd.Series(["abac", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.replace(pat, repl, n=n, regex=regex)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_repeat_unsupported():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(NotImplementedError, match="repeat is not"):
+        ser.str.repeat([1, 2])
+
+
+@pytest.mark.xfail(
+    pa_version_under7p0,
+    reason="Unsupported for pyarrow < 7",
+    raises=NotImplementedError,
+)
+def test_str_repeat():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.repeat(2)
+    expected = pd.Series(["abcabc", None], dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "pat, case, na, exp",
+    [
+        ["ab", False, None, [True, None]],
+        ["Ab", True, None, [False, None]],
+        ["bc", True, None, [False, None]],
+        ["ab", False, True, [True, True]],
+        ["a[a-z]{1}", False, None, [True, None]],
+        ["A[a-z]{1}", True, None, [False, None]],
+    ],
+)
+def test_str_match(pat, case, na, exp):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.match(pat, case=case, na=na)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "pat, case, na, exp",
+    [
+        ["abc", False, None, [True, None]],
+        ["Abc", True, None, [False, None]],
+        ["bc", True, None, [False, None]],
+        ["ab", False, True, [False, True]],
+        ["a[a-z]{2}", False, None, [True, None]],
+        ["A[a-z]{1}", True, None, [False, None]],
+    ],
+)
+def test_str_fullmatch(pat, case, na, exp):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.fullmatch(pat, case=case, na=na)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_fullmatch_escaped_dollar():
+    ser = pd.Series(["a$", "a$b", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.fullmatch(r"a\$")
+    expected = pd.Series([True, False, None], dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "sub, start, end, exp, exp_typ",
+    [
+        ["ab", 0, None, [0, None], pa.int32()],
+        ["bc", 1, 3, [1, None], pa.int64()],
+        ["bc", -2, 3, [1, None], pa.int64()],
+    ],
+)
+def test_str_find(sub, start, end, exp, exp_typ):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.find(sub, start=start, end=end)
+    expected = pd.Series(exp, dtype=ArrowDtype(exp_typ))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_find_notimplemented():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(NotImplementedError, match="find not implemented"):
+        ser.str.find("ab", start=1)
+
+
+@pytest.mark.parametrize(
+    "i, exp",
+    [
+        [1, ["b", "e", None]],
+        [-1, ["c", "e", None]],
+        [2, ["c", None, None]],
+        [-3, ["a", None, None]],
+        [4, [None, None, None]],
+    ],
+)
+def test_str_get(i, exp):
+    ser = pd.Series(["abc", "de", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.get(i)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.xfail(
+    reason="TODO: StringMethods._validate should support Arrow list types",
+    raises=AttributeError,
+)
+def test_str_join():
+    ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))
+    result = ser.str.join("=")
+    expected = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "start, stop, step, exp",
+    [
+        [None, 2, None, ["ab", None]],
+        [None, 2, 1, ["ab", None]],
+        [1, 3, 1, ["bc", None]],
+    ],
+)
+def test_str_slice(start, stop, step, exp):
+    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.slice(start, stop, step)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "start, stop, repl, exp",
+    [
+        [1, 2, "x", ["axcd", None]],
+        [None, 2, "x", ["xcd", None]],
+        [None, 2, None, ["cd", None]],
+    ],
+)
+def test_str_slice_replace(start, stop, repl, exp):
+    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.slice_replace(start, stop, repl)
+    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "value, method, exp",
+    [
+        ["a1c", "isalnum", True],
+        ["!|,", "isalnum", False],
+        ["aaa", "isalpha", True],
+        ["!!!", "isalpha", False],
+        ["٠", "isdecimal", True],
+        ["~!", "isdecimal", False],
+        ["2", "isdigit", True],
+        ["~", "isdigit", False],
+        ["aaa", "islower", True],
+        ["aaA", "islower", False],
+        ["123", "isnumeric", True],
+        ["11I", "isnumeric", False],
+        [" ", "isspace", True],
+        ["", "isspace", False],
+        ["The That", "istitle", True],
+        ["the That", "istitle", False],
+        ["AAA", "isupper", True],
+        ["AAc", "isupper", False],
+    ],
+)
+def test_str_is_functions(value, method, exp):
+    ser = pd.Series([value, None], dtype=ArrowDtype(pa.string()))
+    result = getattr(ser.str, method)()
+    expected = pd.Series([exp, None], dtype=ArrowDtype(pa.bool_()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "method, exp",
+    [
+        ["capitalize", "Abc def"],
+        ["title", "Abc Def"],
+        ["swapcase", "AbC Def"],
+        ["lower", "abc def"],
+        ["upper", "ABC DEF"],
+    ],
+)
+def test_str_transform_functions(method, exp):
+    ser = pd.Series(["aBc dEF", None], dtype=ArrowDtype(pa.string()))
+    result = getattr(ser.str, method)()
+    expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_len():
+    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.len()
+    expected = pd.Series([4, None], dtype=ArrowDtype(pa.int32()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "method, to_strip, val",
+    [
+        ["strip", None, " abc "],
+        ["strip", "x", "xabcx"],
+        ["lstrip", None, " abc"],
+        ["lstrip", "x", "xabc"],
+        ["rstrip", None, "abc "],
+        ["rstrip", "x", "abcx"],
+    ],
+)
+def test_str_strip(method, to_strip, val):
+    ser = pd.Series([val, None], dtype=ArrowDtype(pa.string()))
+    result = getattr(ser.str, method)(to_strip=to_strip)
+    expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize("val", ["abc123", "abc"])
+def test_str_removesuffix(val):
+    ser = pd.Series([val, None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.removesuffix("123")
+    expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    tm.assert_series_equal(result, expected)
+
+
+def test_str_removesuffix_empty_suffix():
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    result = ser.str.removesuffix("")
+    tm.assert_series_equal(result, ser)
+
+
+@pytest.mark.parametrize(
+    "method, args",
+    [
+        ["partition", ("abc", False)],
+        ["rpartition", ("abc", False)],
+        ["removeprefix", ("abc",)],
+        ["casefold", ()],
+        ["encode", ("abc",)],
+        ["extract", (r"[ab](\d)",)],
+        ["findall", ("abc",)],
+        ["get_dummies", ()],
+        ["index", ("abc",)],
+        ["rindex", ("abc",)],
+        ["normalize", ("abc",)],
+        ["rfind", ("abc",)],
+        ["split", ()],
+        ["rsplit", ()],
+        ["translate", ("abc",)],
+        ["wrap", ("abc",)],
+    ],
+)
+def test_str_unsupported_methods(method, args):
+    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
+    with pytest.raises(
+        NotImplementedError, match=f"str.{method} not supported with pd.ArrowDtype"
+    ):
+        getattr(ser.str, method)(*args)
+
+
 @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])
 def test_duration_from_strings_with_nat(unit):
     # GH51175
diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py
index cdc2b876194e6..3e1ee89e9a841 100644
--- a/pandas/tests/strings/conftest.py
+++ b/pandas/tests/strings/conftest.py
@@ -2,7 +2,7 @@
 import pytest
 
 from pandas import Series
-from pandas.core import strings
+from pandas.core.strings.accessor import StringMethods
 
 _any_string_method = [
     ("cat", (), {"sep": ","}),
@@ -89,9 +89,7 @@
     )
 )
 ids, _, _ = zip(*_any_string_method)  # use method name as fixture-id
-missing_methods = {
-    f for f in dir(strings.StringMethods) if not f.startswith("_")
-} - set(ids)
+missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids)
 
 # test that the above list captures all methods of StringMethods
 assert not missing_methods
diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py
index 088affcb0506f..88d928ceecc43 100644
--- a/pandas/tests/strings/test_api.py
+++ b/pandas/tests/strings/test_api.py
@@ -8,13 +8,13 @@
     _testing as tm,
     get_option,
 )
-from pandas.core import strings
+from pandas.core.strings.accessor import StringMethods
 
 
 def test_api(any_string_dtype):
     # GH 6106, GH 9322
-    assert Series.str is strings.StringMethods
-    assert isinstance(Series([""], dtype=any_string_dtype).str, strings.StringMethods)
+    assert Series.str is StringMethods
+    assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods)
 
 
 def test_api_mi_raises():
@@ -44,7 +44,7 @@ def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype):
     ]
     if inferred_dtype in types_passing_constructor:
         # GH 6106
-        assert isinstance(t.str, strings.StringMethods)
+        assert isinstance(t.str, StringMethods)
     else:
         # GH 9184, GH 23011, GH 23163
         msg = "Can only use .str accessor with string values.*"
@@ -137,7 +137,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype, request):
     s = Series(list("aabb"), dtype=any_string_dtype)
     s = s + " " + s
     c = s.astype("category")
-    assert isinstance(c.str, strings.StringMethods)
+    assert isinstance(c.str, StringMethods)
 
     method_name, args, kwargs = any_string_method
 
diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py
index 9340fea14f801..b863425a24183 100644
--- a/pandas/tests/strings/test_strings.py
+++ b/pandas/tests/strings/test_strings.py
@@ -13,6 +13,7 @@
     Series,
 )
 import pandas._testing as tm
+from pandas.core.strings.accessor import StringMethods
 
 
 @pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])])
@@ -598,8 +599,6 @@ def test_normalize_index():
     ],
 )
 def test_index_str_accessor_visibility(values, inferred_type, index_or_series):
-    from pandas.core.strings import StringMethods
-
     obj = index_or_series(values)
     if index_or_series is Index:
         assert obj.inferred_type == inferred_type