From f8930b4f8f118eb9b57b29947a82190a8e9094be Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 15:09:36 -0800 Subject: [PATCH 01/19] Remove __init__ import --- pandas/core/indexes/base.py | 2 +- pandas/core/series.py | 2 +- pandas/core/strings/__init__.py | 5 ----- pandas/tests/strings/test_api.py | 10 +++++----- pandas/tests/strings/test_strings.py | 2 +- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 0caa8005f1ebc..d0ec8b60bb11e 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -170,7 +170,7 @@ get_group_index_sorter, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.io.formats.printing import ( PrettyDict, diff --git a/pandas/core/series.py b/pandas/core/series.py index abe31d0dbd52a..3cd97cc705f83 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -161,7 +161,7 @@ ensure_key_mapped, nargsort, ) -from pandas.core.strings import StringMethods +from pandas.core.strings.accessor import StringMethods from pandas.core.tools.datetimes import to_datetime import pandas.io.formats.format as fmt diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py index 28aba7c9ce0b3..eb650477c2b6b 100644 --- a/pandas/core/strings/__init__.py +++ b/pandas/core/strings/__init__.py @@ -26,8 +26,3 @@ # - PandasArray # - Categorical # - ArrowStringArray - -from pandas.core.strings.accessor import StringMethods -from pandas.core.strings.base import BaseStringArrayMethods - -__all__ = ["StringMethods", "BaseStringArrayMethods"] diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py index 7a6c7e69047bc..036eaf5f0cae4 100644 --- a/pandas/tests/strings/test_api.py +++ b/pandas/tests/strings/test_api.py @@ -8,14 +8,14 @@ _testing as tm, get_option, ) -from pandas.core import strings +from pandas.core.strings.accessor import StringMethods def test_api(any_string_dtype): # GH 6106, GH 9322 - assert Series.str is strings.StringMethods - assert isinstance(Series([""], dtype=any_string_dtype).str, strings.StringMethods) + assert Series.str is StringMethods + assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods) def test_api_mi_raises(): @@ -45,7 +45,7 @@ def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype): ] if inferred_dtype in types_passing_constructor: # GH 6106 - assert isinstance(t.str, strings.StringMethods) + assert isinstance(t.str, StringMethods) else: # GH 9184, GH 23011, GH 23163 msg = "Can only use .str accessor with string values.*" @@ -138,7 +138,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype, request): s = Series(list("aabb"), dtype=any_string_dtype) s = s + " " + s c = s.astype("category") - assert isinstance(c.str, strings.StringMethods) + assert isinstance(c.str, StringMethods) method_name, args, kwargs = any_string_method diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index a9335e156d9db..4a296f197e39f 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -599,7 +599,7 @@ def test_normalize_index(): ], ) def test_index_str_accessor_visibility(values, inferred_type, index_or_series): - from pandas.core.strings import StringMethods + from pandas.core.strings.accessor import StringMethods obj = index_or_series(values) if index_or_series is Index: From 0c31e1e26b38c64935773ed69770fa89de8b69b6 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 20:23:49 -0800 Subject: [PATCH 02/19] Add base class methods --- pandas/core/arrays/arrow/array.py | 283 +++++++++++++++++++++++++++++- 1 file changed, 282 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 9247d26fc846d..661549e3825e3 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,10 +1,13 @@ from __future__ import annotations from copy import deepcopy +import re from typing import ( TYPE_CHECKING, Any, + Callable, Literal, + Sequence, TypeVar, cast, ) @@ -53,6 +56,7 @@ unpack_tuple_and_ellipses, validate_indices, ) +from pandas.core.strings.base import BaseStringArrayMethods if not pa_version_under6p0: import pyarrow as pa @@ -151,7 +155,7 @@ def to_pyarrow_type( return None -class ArrowExtensionArray(OpsMixin, ExtensionArray): +class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods): """ Pandas ExtensionArray backed by a PyArrow ChunkedArray. @@ -1429,3 +1433,280 @@ def _replace_with_mask( result = np.array(values, dtype=object) result[mask] = replacements return pa.array(result, type=values.type, from_pandas=True) + + def _str_count(self, pat: str, flags: int = 0): + if flags: + raise NotImplementedError(f"count not implemented with {flags=}") + return type(self)(pc.count_substring_regex(self._data, pat)) + + def _str_pad( + self, + width: int, + side: Literal["left", "right", "both"] = "left", + fillchar: str = " ", + ): + if side == "left": + pa_pad = pc.utf8_lpad + elif side == "right": + pa_pad = pc.utf8_rpad + elif side == "both": + pa_pad = pc.utf8_center + else: + raise ValueError( + f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" + ) + return type(self)(pa_pad(self._data, width=width, padding=fillchar)) + + def _str_contains( + self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True + ): + if flags: + raise NotImplementedError(f"contains not implemented with {flags=}") + + if regex: + pa_contains = pc.match_substring_regex + else: + pa_contains = pc.match_substring + result = pa_contains(self._data, pat, ignore_case=not case) + if not isna(na): + result = result.fill_null(na) + return type(self)(result) + + def _str_startswith(self, pat: str, na=None): + result = pc.starts_with(self._data, pattern=pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) + + def _str_endswith(self, pat: str, na=None): + result = pc.ends_with(self._data, pattern=pat) + if na is not None: + result = result.fill_null(na) + return type(self)(result) + + def _str_replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool = True, + flags: int = 0, + regex: bool = True, + ): + if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: + raise NotImplementedError( + "replace is not with a regex pat, callable repl, case=False, " + "or flags!=0" + ) + + func = pc.replace_substring_regex if regex else pc.replace_substring + result = func(self._data, pattern=pat, replacement=repl, max_replacements=n) + return type(self)(result) + + def _str_repeat(self, repeats: int | Sequence[int]): + if not isinstance(repeats, int): + raise NotImplementedError( + f"repeat is not implemented when repeats is {type(repeats).__name__}" + ) + elif pa_version_under7p0: + raise NotImplementedError("repeat is not implemented for pyarrow < 7") + else: + return type(self)(pc.binary_repeat(self._data, repeats)) + + def _str_match( + self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.startswith("^"): + pat = f"^{pat}" + return self._str_contains(pat, case, flags, na, regex=True) + + def _str_fullmatch( + self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.endswith("$") or pat.endswith("//$"): + pat = f"{pat}$" + return self._str_match(pat, case, flags, na) + + def _str_find(self, sub: str, start: int = 0, end: int | None = None): + if start != 0 and end is not None: + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) + result = pc.find_substring(slices, sub) + not_found = pc.equal(result, -1) + offset_result = pc.add(result, end - start) + result = pc.if_else(not_found, result, offset_result) + elif start == 0 and end is None: + slices = self._data + result = pc.find_substring(slices, sub) + else: + raise NotImplementedError( + f"find not implemented with {sub=}, {start=}, {end=}" + ) + return type(self)(result) + + def _str_get(self, i: int): + lengths = pc.utf8_length(self._data) + if i >= 0: + out_of_bounds = pc.greater_equal(i, lengths) + start = i + stop = i + 1 + step = 1 + else: + out_of_bounds = pc.greater(-i, lengths) + start = i + stop = i - 1 + step = -1 + not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True)) + selected = pc.utf8_slice_codeunits( + self._data, start=start, stop=stop, step=step + ) + result = pa.array([None] * self._data.length(), type=self._data.type) + result = pc.if_else(not_out_of_bounds, selected, result) + return type(self)(result) + + def _str_join(self, sep: str): + return type(self)(pc.binary_join(self._data, sep)) + + def _str_partition(self, sep: str, expand: bool): + raise NotImplementedError + + def _str_rpartition(self, sep: str, expand: bool): + raise NotImplementedError + + def _str_slice( + self, start: int | None = None, stop: int | None = None, step: int | None = None + ): + if start is None: + start = 0 + if step is None: + step = 1 + return type(self)( + pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step) + ) + + def _str_slice_replace( + self, start: int | None = None, stop: int | None = None, repl: str | None = None + ): + if repl is None: + repl = "" + if start is None: + start = 0 + return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) + + def _str_isalnum(self): + return type(self)(pc.utf8_is_alnum(self._data)) + + def _str_isalpha(self): + return type(self)(pc.utf8_is_alpha(self._data)) + + def _str_isdecimal(self): + return type(self)(pc.utf8_is_decimal(self._data)) + + def _str_isdigit(self): + return type(self)(pc.utf8_is_digit(self._data)) + + def _str_islower(self): + return type(self)(pc.utf8_is_lower(self._data)) + + def _str_isnumeric(self): + return type(self)(pc.utf8_is_numeric(self._data)) + + def _str_isspace(self): + return type(self)(pc.utf8_is_space(self._data)) + + def _str_istitle(self): + return type(self)(pc.utf8_is_title(self._data)) + + def _str_capitalize(self): + return type(self)(pc.utf8_capitalize(self._data)) + + def _str_title(self): + return type(self)(pc.utf8_title(self._data)) + + def _str_isupper(self): + return type(self)(pc.utf8_is_upper(self._data)) + + def _str_swapcase(self): + return type(self)(pc.utf8_swapcase(self._data)) + + def _str_len(self): + return type(self)(pc.utf8_length(self._data)) + + def _str_lower(self): + return type(self)(pc.utf8_lower(self._data)) + + def _str_upper(self): + return type(self)(pc.utf8_upper(self._data)) + + def _str_strip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_trim_whitespace(self._data) + else: + result = pc.utf8_trim(self._data, characters=to_strip) + return type(self)(result) + + def _str_lstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_ltrim_whitespace(self._data) + else: + result = pc.utf8_ltrim(self._data, characters=to_strip) + return type(self)(result) + + def _str_rstrip(self, to_strip=None): + if to_strip is None: + result = pc.utf8_rtrim_whitespace(self._data) + else: + result = pc.utf8_rtrim(self._data, characters=to_strip) + return type(self)(result) + + def _str_removeprefix(self, prefix: str) -> Series: + raise NotImplementedError + # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed + # starts_with = pc.starts_with(self._data, pattern=prefix) + # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) + # result = pc.if_else(starts_with, removed, self._data) + # return type(self)(result) + + def _str_removesuffix(self, suffix: str) -> Series: + ends_with = pc.ends_with(self._data, pattern=suffix) + removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) + result = pc.if_else(ends_with, removed, self._data) + return type(self)(result) + + def _str_casefold(self): + raise NotImplementedError + + def _str_encode(self, encoding, errors: str = "strict"): + raise NotImplementedError + + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + raise NotImplementedError + + def _str_findall(self, pat, flags: int = 0): + raise NotImplementedError + + def _str_get_dummies(self, sep: str = "|"): + raise NotImplementedError + + def _str_index(self, sub, start: int = 0, end=None): + raise NotImplementedError + + def _str_rindex(self, sub, start: int = 0, end=None): + raise NotImplementedError + + def _str_normalize(self, form): + raise NotImplementedError + + def _str_rfind(self, sub, start: int = 0, end=None): + raise NotImplementedError + + def _str_split(self, pat=None, n=-1, expand: bool = False): + raise NotImplementedError + + def _str_rsplit(self, pat=None, n=-1): + raise NotImplementedError + + def _str_translate(self, table): + raise NotImplementedError + + def _str_wrap(self, width, **kwargs): + raise NotImplementedError From f8aa6e5355cdd50cb2c57cc889c38df6cf5166a7 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 20:28:22 -0800 Subject: [PATCH 03/19] Adapt for groupby --- pandas/core/groupby/ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index bff61ec135d74..5b87938d62df4 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -56,6 +56,7 @@ is_numeric_dtype, is_period_dtype, is_sparse, + is_string_dtype, is_timedelta64_dtype, needs_i8_conversion, ) @@ -76,7 +77,6 @@ BaseMaskedArray, BaseMaskedDtype, ) -from pandas.core.arrays.string_ import StringDtype from pandas.core.frame import DataFrame from pandas.core.groupby import grouper from pandas.core.indexes.api import ( @@ -96,6 +96,7 @@ ) if TYPE_CHECKING: + from pandas.core.arrays.string_ import StringDtype from pandas.core.generic import NDFrame @@ -387,7 +388,7 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray: # All of the functions implemented here are ordinal, so we can # operate on the tz-naive equivalents npvalues = values._ndarray.view("M8[ns]") - elif isinstance(values.dtype, StringDtype): + elif is_string_dtype(values.dtype): # StringArray npvalues = values.to_numpy(object, na_value=np.nan) else: @@ -405,7 +406,7 @@ def _reconstruct_ea_result( """ dtype: BaseMaskedDtype | StringDtype - if isinstance(values.dtype, StringDtype): + if is_string_dtype(values.dtype): dtype = values.dtype string_array_cls = dtype.construct_array_type() return string_array_cls._from_sequence(res_values, dtype=dtype) From 092f58ffe44e3b60a046a46cb3271d6b13e750e3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 21:04:46 -0800 Subject: [PATCH 04/19] Start adding test pt 1 --- pandas/core/arrays/arrow/array.py | 4 +- pandas/tests/extension/test_arrow.py | 55 ++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 661549e3825e3..890cb9fc752af 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1474,13 +1474,13 @@ def _str_contains( def _str_startswith(self, pat: str, na=None): result = pc.starts_with(self._data, pattern=pat) - if na is not None: + if not isna(na): result = result.fill_null(na) return type(self)(result) def _str_endswith(self, pat: str, na=None): result = pc.ends_with(self._data, pattern=pat) - if na is not None: + if not isna(na): result = result.fill_null(na) return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e31d8605eeb06..884d2cffc1dba 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1657,3 +1657,58 @@ def test_searchsorted_with_na_raises(data_for_sorting, as_series): ) with pytest.raises(ValueError, match=msg): arr.searchsorted(b) + + +@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"]) +def test_str_count(pat): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.count(pat) + expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32())) + tm.assert_series_equal(result, expected) + + +def test_str_count_flags_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="count not"): + ser.str.count("abc", flags=1) + + +@pytest.mark.parametrize( + "side, str_func", [["left", "rjust"], ["right", "ljust"], ["both", "center"]] +) +def test_str_pad(side, str_func): + ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) + result = ser.str.pad(width=3, side=side, fillchar="x") + expected = pd.Series( + [getattr("a", str_func)(3, "x"), None], dtype=ArrowDtype(pa.string()) + ) + tm.assert_series_equal(result, expected) + + +def test_str_pad_invalid_side(): + ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(ValueError, match="Invalid side: foo"): + ser.str.pad(3, "foo", "x") + + +@pytest.mark.parametrize( + "pat, case, na, regex, exp", + [ + ["ab", False, None, False, [True, None]], + ["Ab", True, None, False, [False, None]], + ["ab", False, True, False, [True, True]], + ["a[a-z]{1}", False, None, True, [True, None]], + ["A[a-z]{1}", True, None, True, [False, None]], + ], +) +def test_str_contains(pat, case, na, regex, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.contains(pat, case=case, na=na, regex=regex) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +def test_str_contains_flags_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="contains not"): + ser.str.contains("a", flags=1) From 10ad91239568b0d5b5c3ba0b2b8ef74d745a07c0 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 21:35:13 -0800 Subject: [PATCH 05/19] Add more tests part 2 --- pandas/core/arrays/arrow/array.py | 4 +-- pandas/tests/extension/test_arrow.py | 44 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 890cb9fc752af..d76777fbde748 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1495,8 +1495,8 @@ def _str_replace( ): if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: raise NotImplementedError( - "replace is not with a regex pat, callable repl, case=False, " - "or flags!=0" + "replace is not supported with a re.Pattern, callable repl, " + "case=False, or flags!=0" ) func = pc.replace_substring_regex if regex else pc.replace_substring diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 884d2cffc1dba..e859efd0948e3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -21,6 +21,7 @@ StringIO, ) import pickle +import re import numpy as np import pytest @@ -1712,3 +1713,46 @@ def test_str_contains_flags_unsupported(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) with pytest.raises(NotImplementedError, match="contains not"): ser.str.contains("a", flags=1) + + +@pytest.mark.parametrize( + "side, pat, na, exp", + [ + ["startswith", "ab", None, [True, None]], + ["startswith", "b", False, [False, False]], + ["endswith", "b", True, [False, True]], + ["endswith", "bc", None, [True, None]], + ], +) +def test_str_start_ends_with(side, pat, na, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, side)(pat, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arg_name, arg", + [["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]], +) +def test_str_replace_unsupported(arg_name, arg): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + kwargs = {"pat": "b", "repl": "x", "regex": True} + kwargs[arg_name] = arg + with pytest.raises(NotImplementedError, match="replace is not supported"): + ser.str.replace(**kwargs) + + +@pytest.mark.parametrize( + "pat, repl, n, regex, exp", + [ + ["a", "x", -1, False, ["xbxc", None]], + ["a", "x", 1, False, ["xbac", None]], + ["[a-b]", "x", -1, True, ["xxxc", None]], + ], +) +def test_str_replace(pat, repl, n, regex, exp): + ser = pd.Series(["abac", None], dtype=ArrowDtype(pa.string())) + result = ser.str.replace(pat, repl, n=n, regex=regex) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) From 0ae2e1d4275bd74c42d0e64c2195fcf4de27127e Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 3 Feb 2023 21:53:57 -0800 Subject: [PATCH 06/19] Test pt 3 --- pandas/tests/extension/test_arrow.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e859efd0948e3..d9604d81c2029 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1756,3 +1756,57 @@ def test_str_replace(pat, repl, n, regex, exp): result = ser.str.replace(pat, repl, n=n, regex=regex) expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) tm.assert_series_equal(result, expected) + + +def test_str_repeat_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="repeat is not"): + ser.str.repeat([1, 2]) + + +@pytest.mark.xfail( + pa_version_under7p0, + reason="Unsupported for pyarrow < 7", + raises=NotImplementedError, +) +def test_str_repeat(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.repeat(2) + expected = pd.Series(["abcabc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pat, case, na, exp", + [ + ["ab", False, None, [True, None]], + ["Ab", True, None, [False, None]], + ["bc", True, None, [False, None]], + ["ab", False, True, [True, True]], + ["a[a-z]{1}", False, None, [True, None]], + ["A[a-z]{1}", True, None, [False, None]], + ], +) +def test_str_match(pat, case, na, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.match(pat, case=case, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pat, case, na, exp", + [ + ["abc", False, None, [True, None]], + ["Abc", True, None, [False, None]], + ["bc", True, None, [False, None]], + ["ab", False, True, [True, True]], + ["a[a-z]{2}", False, None, [True, None]], + ["A[a-z]{1}", True, None, [False, None]], + ], +) +def test_str_fullmatch(pat, case, na, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.match(pat, case=case, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) From 37d0add60925888b5970ccb29e5f04f34bc0d973 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 6 Feb 2023 16:07:29 -0800 Subject: [PATCH 07/19] More tests --- pandas/core/arrays/arrow/array.py | 4 +- pandas/tests/extension/test_arrow.py | 67 ++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d76777fbde748..fda1971c9f7c0 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1567,10 +1567,10 @@ def _str_join(self, sep: str): return type(self)(pc.binary_join(self._data, sep)) def _str_partition(self, sep: str, expand: bool): - raise NotImplementedError + raise NotImplementedError("str.partition not supported.") def _str_rpartition(self, sep: str, expand: bool): - raise NotImplementedError + raise NotImplementedError("str.rpartition not supported.") def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d9604d81c2029..32830bc255281 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1810,3 +1810,70 @@ def test_str_fullmatch(pat, case, na, exp): result = ser.str.match(pat, case=case, na=na) expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "sub, start, end, exp, exp_typ", + [["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [2, None], pa.int64()]], +) +def test_str_find(sub, start, end, exp, exp_typ): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub, start=start, end=end) + expected = pd.Series(exp, dtype=ArrowDtype(exp_typ)) + tm.assert_series_equal(result, expected) + + +def test_str_find_notimplemented(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="find not implemented"): + ser.str.find("ab", start=1) + + +@pytest.mark.parametrize( + "i, exp", + [ + [1, ["b", "e", None]], + [-1, ["c", "e", None]], + [2, ["c", None, None]], + [-3, ["a", None, None]], + [4, [None, None, None]], + ], +) +def test_str_get(i, exp): + ser = pd.Series(["abc", "de", None], dtype=ArrowDtype(pa.string())) + result = ser.str.get(i) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.xfail( + reason="TODO: StringMethods._validate should support Arrow list types", + raises=AttributeError, +) +def test_str_join(): + ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None]))) + result = ser.str.join("=") + expected = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["partition", "rpartition"]) +def test_str_partition_unsupported(method): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="str."): + getattr(ser.str, method)("", False) + + +@pytest.mark.parametrize( + "start, stop, step, exp", + [ + [None, 2, None, ["ab", None]], + [None, 2, 1, ["ab", None]], + [None, 2, 1, ["bc", None]], + ], +) +def test_str_slice(start, stop, step, exp): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.slice(start, stop, step) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) From 4e6d24e8e18d880ce412b7c173b80dbe23b499ed Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 6 Feb 2023 17:18:32 -0800 Subject: [PATCH 08/19] finish tests --- pandas/core/arrays/arrow/array.py | 32 +++---- pandas/core/strings/base.py | 4 +- pandas/tests/extension/test_arrow.py | 129 +++++++++++++++++++++++++-- 3 files changed, 142 insertions(+), 23 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index fda1971c9f7c0..5657de5728b31 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1659,7 +1659,7 @@ def _str_rstrip(self, to_strip=None): return type(self)(result) def _str_removeprefix(self, prefix: str) -> Series: - raise NotImplementedError + raise NotImplementedError("str.removeprefix not supported.") # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed # starts_with = pc.starts_with(self._data, pattern=prefix) # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) @@ -1673,40 +1673,42 @@ def _str_removesuffix(self, suffix: str) -> Series: return type(self)(result) def _str_casefold(self): - raise NotImplementedError + raise NotImplementedError("str.casefold not supported.") def _str_encode(self, encoding, errors: str = "strict"): - raise NotImplementedError + raise NotImplementedError("str.encode not supported.") def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): - raise NotImplementedError + raise NotImplementedError("str.extract not supported.") def _str_findall(self, pat, flags: int = 0): - raise NotImplementedError + raise NotImplementedError("str.findall not supported.") def _str_get_dummies(self, sep: str = "|"): - raise NotImplementedError + raise NotImplementedError("str.get_dummies not supported.") def _str_index(self, sub, start: int = 0, end=None): - raise NotImplementedError + raise NotImplementedError("str.index not supported.") def _str_rindex(self, sub, start: int = 0, end=None): - raise NotImplementedError + raise NotImplementedError("str.rindex not supported.") def _str_normalize(self, form): - raise NotImplementedError + raise NotImplementedError("str.normalize not supported.") def _str_rfind(self, sub, start: int = 0, end=None): - raise NotImplementedError + raise NotImplementedError("str.rfind not supported.") - def _str_split(self, pat=None, n=-1, expand: bool = False): - raise NotImplementedError + def _str_split( + self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + ): + raise NotImplementedError("str.split not supported.") def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError + raise NotImplementedError("str.rsplit not supported.") def _str_translate(self, table): - raise NotImplementedError + raise NotImplementedError("str.translate not supported.") def _str_wrap(self, width, **kwargs): - raise NotImplementedError + raise NotImplementedError("str.wrap not supported.") diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index c96e5a1abcf86..f1e716b64644a 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -246,7 +246,9 @@ def _str_removesuffix(self, suffix: str) -> Series: pass @abc.abstractmethod - def _str_split(self, pat=None, n=-1, expand: bool = False): + def _str_split( + self, pat=None, n=-1, expand: bool = False, regex: bool | None = None + ): pass @abc.abstractmethod diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 32830bc255281..426c839ef8b2c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1857,13 +1857,6 @@ def test_str_join(): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("method", ["partition", "rpartition"]) -def test_str_partition_unsupported(method): - ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - with pytest.raises(NotImplementedError, match="str."): - getattr(ser.str, method)("", False) - - @pytest.mark.parametrize( "start, stop, step, exp", [ @@ -1877,3 +1870,125 @@ def test_str_slice(start, stop, step, exp): result = ser.str.slice(start, stop, step) expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, repl, exp", + [ + [1, 2, "x", ["axcd", None]], + [None, 2, "x", ["xcd", None]], + [None, 2, None, ["cd", None]], + ], +) +def test_str_slice_replace(start, stop, repl, exp): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.slice_replace(start, stop, repl) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "value, method, exp", + [ + ["a1c", "isalnum", True], + ["!|,", "isalnum", False], + ["aaa", "isalpha", True], + ["!!!", "isalpha", False], + ["٠", "isdecimal", True], + ["~!", "isdecimal", False], + ["2", "isdigit", True], + ["~", "isdigit", False], + ["aaa", "islower", True], + ["aaA", "islower", False], + ["123", "isnumeric", True], + ["11I", "isnumeric", False], + [" ", "isspace", True], + ["", "isspace", False], + ["The That", "istitle", True], + ["the That", "istitle", False], + ["AAA", "isupper", True], + ["AAc", "isupper", False], + ], +) +def test_str_is_functions(value, method, exp): + ser = pd.Series([value, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["capitalize", "Abc def"], + ["title", "Abc Def"], + ["swapcase", "AbC Def"], + ["lower", "abc def"], + ["upper", "ABC DEF"], + ], +) +def test_str_transform_functions(method, exp): + ser = pd.Series(["aBc dEF", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_len(): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.len() + expected = pd.Series([4, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, to_strip, val", + [ + ["strip", None, " abc "], + ["strip", "x", "xabcx"], + ["lstrip", None, " abc"], + ["lstrip", "x", "xabc"], + ["rstrip", None, "abc "], + ["rstrip", "x", "abcx"], + ], +) +def test_str_strip(method, to_strip, val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)(to_strip=to_strip) + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("val", ["abc123", "abc"]) +def test_str_removesuffix(val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = ser.str.removesuffix("123") + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, args", + [ + ["partition", ("abc", False)], + ["rpartition", ("abc", False)], + ["removeprefix", ("abc",)], + ["casefold", ()], + ["encode", ("abc",)], + ["extract", (r"[ab](\d)",)], + ["findall", ("abc",)], + ["get_dummies", ()], + ["index", ("abc",)], + ["rindex", ("abc",)], + ["normalize", ("abc",)], + ["rfind", ("abc",)], + ["split", ()], + ["rsplit", ()], + ["translate", ("abc",)], + ["wrap", ("abc",)], + ], +) +def test_str_unsupported_methods(method, args): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match=f"str.{method} not supported."): + getattr(ser.str, method)(*args) From c74473f191124edf057c668fab9a1d9dcdb20971 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 6 Feb 2023 18:18:13 -0800 Subject: [PATCH 09/19] xfail dask test due to moved path --- pandas/tests/test_downstream.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pandas/tests/test_downstream.py b/pandas/tests/test_downstream.py index fea10075a0ace..a38850b80565d 100644 --- a/pandas/tests/test_downstream.py +++ b/pandas/tests/test_downstream.py @@ -17,6 +17,7 @@ Series, ) import pandas._testing as tm +from pandas.util.version import Version def import_module(name): @@ -34,7 +35,7 @@ def df(): return DataFrame({"A": [1, 2, 3]}) -def test_dask(df): +def test_dask(df, request): # dask sets "compute.use_numexpr" to False, so catch the current value # and ensure to reset it afterwards to avoid impacting other tests @@ -44,6 +45,14 @@ def test_dask(df): toolz = import_module("toolz") # noqa:F841 dask = import_module("dask") # noqa:F841 + if Version(dask.__version__) < Version("2023.1.1"): + request.node.add_marker( + pytest.mark.xfail( + reason="Used pandas.core.strings.StringMethods which moved", + raises=AttributeError, + ) + ) + import dask.dataframe as dd ddf = dd.from_pandas(df, npartitions=3) From d7a463fcef1bf088a7801e3b26b4725d759b2dac Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 6 Feb 2023 18:21:50 -0800 Subject: [PATCH 10/19] Add whatsnew --- doc/source/whatsnew/v2.0.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index addebad8b45af..b5fcaf2f350f0 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -263,6 +263,7 @@ Alternatively, copy on write can be enabled locally through: Other enhancements ^^^^^^^^^^^^^^^^^^ +- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string())`` (:issue:`50325`) - :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`) - :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`) - :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`) From 62d2f581658289997ebcc02a5803282325eec35b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 7 Feb 2023 14:01:58 -0800 Subject: [PATCH 11/19] Fix import --- pandas/core/arrays/string_arrow.py | 2 +- pandas/tests/strings/conftest.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 9ad92471b98b4..3da7143b03566 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -60,7 +60,7 @@ def _chk_pyarrow_available() -> None: # fallback for the ones that pyarrow doesn't yet support -class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin): +class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray): """ Extension array for string data in a ``pyarrow.ChunkedArray``. diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py index cdc2b876194e6..3e1ee89e9a841 100644 --- a/pandas/tests/strings/conftest.py +++ b/pandas/tests/strings/conftest.py @@ -2,7 +2,7 @@ import pytest from pandas import Series -from pandas.core import strings +from pandas.core.strings.accessor import StringMethods _any_string_method = [ ("cat", (), {"sep": ","}), @@ -89,9 +89,7 @@ ) ) ids, _, _ = zip(*_any_string_method) # use method name as fixture-id -missing_methods = { - f for f in dir(strings.StringMethods) if not f.startswith("_") -} - set(ids) +missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids) # test that the above list captures all methods of StringMethods assert not missing_methods From 64552146ac828e6ebabe3c9ef60d021a5e195898 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 7 Feb 2023 14:28:11 -0800 Subject: [PATCH 12/19] Define len --- pandas/core/arrays/string_arrow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 3da7143b03566..4d2b39ec61fca 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -117,6 +117,16 @@ def __init__(self, values) -> None: "ArrowStringArray requires a PyArrow (chunked) array of string type" ) + def __len__(self) -> int: + """ + Length of this array. + + Returns + ------- + length : int + """ + return len(self._data) + @classmethod def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False): from pandas.core.arrays.masked import BaseMaskedArray From b57f142e3273d0d82a732ddf889baf44c0b5a643 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 8 Feb 2023 14:34:45 -0800 Subject: [PATCH 13/19] Address some comments --- pandas/core/groupby/ops.py | 3 ++- pandas/tests/strings/test_strings.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 185ea1aa258c3..381a6ffbda87f 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -96,6 +96,7 @@ ) if TYPE_CHECKING: + from pandas.core.arrays.arrow.dtype import ArrowDtype from pandas.core.arrays.string_ import StringDtype from pandas.core.generic import NDFrame @@ -404,7 +405,7 @@ def _reconstruct_ea_result( """ Construct an ExtensionArray result from an ndarray result. """ - dtype: BaseMaskedDtype | StringDtype + dtype: BaseMaskedDtype | StringDtype | ArrowDtype if is_string_dtype(values.dtype): dtype = values.dtype diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 4a296f197e39f..c4ff608b9a548 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -13,6 +13,7 @@ Series, ) import pandas._testing as tm +from pandas.core.strings.accessor import StringMethods @pytest.mark.parametrize("pattern", [0, True, Series(["foo", "bar"])]) @@ -599,8 +600,6 @@ def test_normalize_index(): ], ) def test_index_str_accessor_visibility(values, inferred_type, index_or_series): - from pandas.core.strings.accessor import StringMethods - obj = index_or_series(values) if index_or_series is Index: assert obj.inferred_type == inferred_type From 852920d9f0a1c26092ebb54db12d1f81ea1d05d3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 8 Feb 2023 14:52:06 -0800 Subject: [PATCH 14/19] address more dask tests --- pandas/tests/test_downstream.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/pandas/tests/test_downstream.py b/pandas/tests/test_downstream.py index a38850b80565d..ed627e37723d3 100644 --- a/pandas/tests/test_downstream.py +++ b/pandas/tests/test_downstream.py @@ -45,7 +45,7 @@ def test_dask(df, request): toolz = import_module("toolz") # noqa:F841 dask = import_module("dask") # noqa:F841 - if Version(dask.__version__) < Version("2023.1.1"): + if Version(dask.__version__) <= Version("2023.1.1"): request.node.add_marker( pytest.mark.xfail( reason="Used pandas.core.strings.StringMethods which moved", @@ -62,16 +62,33 @@ def test_dask(df, request): pd.set_option("compute.use_numexpr", olduse) -def test_dask_ufunc(): +def test_dask_ufunc(request): # dask sets "compute.use_numexpr" to False, so catch the current value # and ensure to reset it afterwards to avoid impacting other tests olduse = pd.get_option("compute.use_numexpr") try: dask = import_module("dask") # noqa:F841 + + if Version(dask.__version__) <= Version("2023.1.1"): + request.node.add_marker( + pytest.mark.xfail( + reason="Used pandas.core.strings.StringMethods which moved", + raises=AttributeError, + ) + ) + import dask.array as da import dask.dataframe as dd + if Version(dask.__version__) < Version("2023.1.1"): + request.node.add_marker( + pytest.mark.xfail( + reason="Used pandas.core.strings.StringMethods which moved", + raises=AttributeError, + ) + ) + s = Series([1.5, 2.3, 3.7, 4.0]) ds = dd.from_pandas(s, npartitions=2) @@ -83,9 +100,19 @@ def test_dask_ufunc(): @td.skip_if_no("dask") -def test_construct_dask_float_array_int_dtype_match_ndarray(): +def test_construct_dask_float_array_int_dtype_match_ndarray(request): # GH#40110 make sure we treat a float-dtype dask array with the same # rules we would for an ndarray + + dask = pytest.importorskip("dask") + if Version(dask.__version__) <= Version("2023.1.1"): + request.node.add_marker( + pytest.mark.xfail( + reason="Used pandas.core.strings.StringMethods which moved", + raises=AttributeError, + ) + ) + import dask.dataframe as dd arr = np.array([1, 2.5, 3]) From 53d20838bc6308b55c254f5fc37797497f7be4ec Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 8 Feb 2023 15:53:35 -0800 Subject: [PATCH 15/19] Revert groupby change --- pandas/core/groupby/ops.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 381a6ffbda87f..c2e3eb49723ec 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -56,7 +56,6 @@ is_numeric_dtype, is_period_dtype, is_sparse, - is_string_dtype, is_timedelta64_dtype, needs_i8_conversion, ) @@ -77,6 +76,7 @@ BaseMaskedArray, BaseMaskedDtype, ) +from pandas.core.arrays.string_ import StringDtype from pandas.core.frame import DataFrame from pandas.core.groupby import grouper from pandas.core.indexes.api import ( @@ -96,8 +96,6 @@ ) if TYPE_CHECKING: - from pandas.core.arrays.arrow.dtype import ArrowDtype - from pandas.core.arrays.string_ import StringDtype from pandas.core.generic import NDFrame @@ -389,7 +387,7 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray: # All of the functions implemented here are ordinal, so we can # operate on the tz-naive equivalents npvalues = values._ndarray.view("M8[ns]") - elif is_string_dtype(values.dtype): + elif isinstance(values.dtype, StringDtype): # StringArray npvalues = values.to_numpy(object, na_value=np.nan) else: @@ -405,9 +403,9 @@ def _reconstruct_ea_result( """ Construct an ExtensionArray result from an ndarray result. """ - dtype: BaseMaskedDtype | StringDtype | ArrowDtype + dtype: BaseMaskedDtype | StringDtype - if is_string_dtype(values.dtype): + if isinstance(values.dtype, StringDtype): dtype = values.dtype string_array_cls = dtype.construct_array_type() return string_array_cls._from_sequence(res_values, dtype=dtype) From 977b698bb43b629cd9c53de7fdb8c2ab5d0f062d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 8 Feb 2023 15:59:12 -0800 Subject: [PATCH 16/19] fix some tests --- pandas/tests/extension/test_arrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index ff210fcc8276a..03691232c5eb4 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1804,7 +1804,7 @@ def test_str_join(): [ [None, 2, None, ["ab", None]], [None, 2, 1, ["ab", None]], - [None, 2, 1, ["bc", None]], + [1, 3, 1, ["bc", None]], ], ) def test_str_slice(start, stop, step, exp): @@ -1879,7 +1879,7 @@ def test_str_transform_functions(method, exp): def test_str_len(): ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) result = ser.str.len() - expected = pd.Series([4, None], dtype=ArrowDtype(pa.int64())) + expected = pd.Series([4, None], dtype=ArrowDtype(pa.int32())) tm.assert_series_equal(result, expected) From 5bb7d2a806ad134a75f4b815413a3f21b33353e8 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 9 Feb 2023 14:31:28 -0800 Subject: [PATCH 17/19] Typing --- pandas/core/arrays/arrow/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index fb1879fee8dc4..05ae3cb459af2 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1680,7 +1680,7 @@ def _str_rstrip(self, to_strip=None): result = pc.utf8_rtrim(self._data, characters=to_strip) return type(self)(result) - def _str_removeprefix(self, prefix: str) -> Series: + def _str_removeprefix(self, prefix: str): raise NotImplementedError("str.removeprefix not supported.") # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed # starts_with = pc.starts_with(self._data, pattern=prefix) @@ -1688,7 +1688,7 @@ def _str_removeprefix(self, prefix: str) -> Series: # result = pc.if_else(starts_with, removed, self._data) # return type(self)(result) - def _str_removesuffix(self, suffix: str) -> Series: + def _str_removesuffix(self, suffix: str): ends_with = pc.ends_with(self._data, pattern=suffix) removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) result = pc.if_else(ends_with, removed, self._data) From 5d7ff0c79476b49a102e686bfbb98d0f1f651098 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 14 Feb 2023 13:09:29 -0800 Subject: [PATCH 18/19] Undo downstream changes --- pandas/tests/test_downstream.py | 49 ++------------------------------- 1 file changed, 3 insertions(+), 46 deletions(-) diff --git a/pandas/tests/test_downstream.py b/pandas/tests/test_downstream.py index ed627e37723d3..b17dce234043c 100644 --- a/pandas/tests/test_downstream.py +++ b/pandas/tests/test_downstream.py @@ -17,7 +17,6 @@ Series, ) import pandas._testing as tm -from pandas.util.version import Version def import_module(name): @@ -35,8 +34,7 @@ def df(): return DataFrame({"A": [1, 2, 3]}) -def test_dask(df, request): - +def test_dask(df): # dask sets "compute.use_numexpr" to False, so catch the current value # and ensure to reset it afterwards to avoid impacting other tests olduse = pd.get_option("compute.use_numexpr") @@ -45,14 +43,6 @@ def test_dask(df, request): toolz = import_module("toolz") # noqa:F841 dask = import_module("dask") # noqa:F841 - if Version(dask.__version__) <= Version("2023.1.1"): - request.node.add_marker( - pytest.mark.xfail( - reason="Used pandas.core.strings.StringMethods which moved", - raises=AttributeError, - ) - ) - import dask.dataframe as dd ddf = dd.from_pandas(df, npartitions=3) @@ -62,33 +52,16 @@ def test_dask(df, request): pd.set_option("compute.use_numexpr", olduse) -def test_dask_ufunc(request): +def test_dask_ufunc(): # dask sets "compute.use_numexpr" to False, so catch the current value # and ensure to reset it afterwards to avoid impacting other tests olduse = pd.get_option("compute.use_numexpr") try: dask = import_module("dask") # noqa:F841 - - if Version(dask.__version__) <= Version("2023.1.1"): - request.node.add_marker( - pytest.mark.xfail( - reason="Used pandas.core.strings.StringMethods which moved", - raises=AttributeError, - ) - ) - import dask.array as da import dask.dataframe as dd - if Version(dask.__version__) < Version("2023.1.1"): - request.node.add_marker( - pytest.mark.xfail( - reason="Used pandas.core.strings.StringMethods which moved", - raises=AttributeError, - ) - ) - s = Series([1.5, 2.3, 3.7, 4.0]) ds = dd.from_pandas(s, npartitions=2) @@ -100,19 +73,9 @@ def test_dask_ufunc(request): @td.skip_if_no("dask") -def test_construct_dask_float_array_int_dtype_match_ndarray(request): +def test_construct_dask_float_array_int_dtype_match_ndarray(): # GH#40110 make sure we treat a float-dtype dask array with the same # rules we would for an ndarray - - dask = pytest.importorskip("dask") - if Version(dask.__version__) <= Version("2023.1.1"): - request.node.add_marker( - pytest.mark.xfail( - reason="Used pandas.core.strings.StringMethods which moved", - raises=AttributeError, - ) - ) - import dask.dataframe as dd arr = np.array([1, 2.5, 3]) @@ -137,7 +100,6 @@ def test_construct_dask_float_array_int_dtype_match_ndarray(request): def test_xarray(df): - xarray = import_module("xarray") # noqa:F841 assert df.to_xarray() is not None @@ -180,7 +142,6 @@ def test_oo_optimized_datetime_index_unpickle(): @pytest.mark.network @tm.network def test_statsmodels(): - statsmodels = import_module("statsmodels") # noqa:F841 import statsmodels.api as sm import statsmodels.formula.api as smf @@ -190,7 +151,6 @@ def test_statsmodels(): def test_scikit_learn(): - sklearn = import_module("sklearn") # noqa:F841 from sklearn import ( datasets, @@ -206,7 +166,6 @@ def test_scikit_learn(): @pytest.mark.network @tm.network def test_seaborn(): - seaborn = import_module("seaborn") tips = seaborn.load_dataset("tips") seaborn.stripplot(x="day", y="total_bill", data=tips) @@ -226,13 +185,11 @@ def test_pandas_gbq(): "variable or through the environmental variable QUANDL_API_KEY", ) def test_pandas_datareader(): - pandas_datareader = import_module("pandas_datareader") pandas_datareader.DataReader("F", "quandl", "2017-01-01", "2017-02-01") def test_pyarrow(df): - pyarrow = import_module("pyarrow") table = pyarrow.Table.from_pandas(df) result = table.to_pandas() From 35e35e3358cb55e44150d316b451d45656deb452 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:15:27 -0800 Subject: [PATCH 19/19] Improve error message --- pandas/core/arrays/arrow/array.py | 64 +++++++++++++++++++++------- pandas/tests/extension/test_arrow.py | 4 +- 2 files changed, 51 insertions(+), 17 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 607593d3bfadc..df8d6172dd6f6 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1600,10 +1600,14 @@ def _str_join(self, sep: str): return type(self)(pc.binary_join(self._data, sep)) def _str_partition(self, sep: str, expand: bool): - raise NotImplementedError("str.partition not supported.") + raise NotImplementedError( + "str.partition not supported with pd.ArrowDtype(pa.string())." + ) def _str_rpartition(self, sep: str, expand: bool): - raise NotImplementedError("str.rpartition not supported.") + raise NotImplementedError( + "str.rpartition not supported with pd.ArrowDtype(pa.string())." + ) def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None @@ -1692,7 +1696,9 @@ def _str_rstrip(self, to_strip=None): return type(self)(result) def _str_removeprefix(self, prefix: str): - raise NotImplementedError("str.removeprefix not supported.") + raise NotImplementedError( + "str.removeprefix not supported with pd.ArrowDtype(pa.string())." + ) # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed # starts_with = pc.starts_with(self._data, pattern=prefix) # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) @@ -1706,45 +1712,71 @@ def _str_removesuffix(self, suffix: str): return type(self)(result) def _str_casefold(self): - raise NotImplementedError("str.casefold not supported.") + raise NotImplementedError( + "str.casefold not supported with pd.ArrowDtype(pa.string())." + ) def _str_encode(self, encoding, errors: str = "strict"): - raise NotImplementedError("str.encode not supported.") + raise NotImplementedError( + "str.encode not supported with pd.ArrowDtype(pa.string())." + ) def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): - raise NotImplementedError("str.extract not supported.") + raise NotImplementedError( + "str.extract not supported with pd.ArrowDtype(pa.string())." + ) def _str_findall(self, pat, flags: int = 0): - raise NotImplementedError("str.findall not supported.") + raise NotImplementedError( + "str.findall not supported with pd.ArrowDtype(pa.string())." + ) def _str_get_dummies(self, sep: str = "|"): - raise NotImplementedError("str.get_dummies not supported.") + raise NotImplementedError( + "str.get_dummies not supported with pd.ArrowDtype(pa.string())." + ) def _str_index(self, sub, start: int = 0, end=None): - raise NotImplementedError("str.index not supported.") + raise NotImplementedError( + "str.index not supported with pd.ArrowDtype(pa.string())." + ) def _str_rindex(self, sub, start: int = 0, end=None): - raise NotImplementedError("str.rindex not supported.") + raise NotImplementedError( + "str.rindex not supported with pd.ArrowDtype(pa.string())." + ) def _str_normalize(self, form): - raise NotImplementedError("str.normalize not supported.") + raise NotImplementedError( + "str.normalize not supported with pd.ArrowDtype(pa.string())." + ) def _str_rfind(self, sub, start: int = 0, end=None): - raise NotImplementedError("str.rfind not supported.") + raise NotImplementedError( + "str.rfind not supported with pd.ArrowDtype(pa.string())." + ) def _str_split( self, pat=None, n=-1, expand: bool = False, regex: bool | None = None ): - raise NotImplementedError("str.split not supported.") + raise NotImplementedError( + "str.split not supported with pd.ArrowDtype(pa.string())." + ) def _str_rsplit(self, pat=None, n=-1): - raise NotImplementedError("str.rsplit not supported.") + raise NotImplementedError( + "str.rsplit not supported with pd.ArrowDtype(pa.string())." + ) def _str_translate(self, table): - raise NotImplementedError("str.translate not supported.") + raise NotImplementedError( + "str.translate not supported with pd.ArrowDtype(pa.string())." + ) def _str_wrap(self, width, **kwargs): - raise NotImplementedError("str.wrap not supported.") + raise NotImplementedError( + "str.wrap not supported with pd.ArrowDtype(pa.string())." + ) @property def _dt_day(self): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index a8124488f6e3e..50fb636c2beb8 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1925,7 +1925,9 @@ def test_str_removesuffix(val): ) def test_str_unsupported_methods(method, args): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - with pytest.raises(NotImplementedError, match=f"str.{method} not supported."): + with pytest.raises( + NotImplementedError, match=f"str.{method} not supported with pd.ArrowDtype" + ): getattr(ser.str, method)(*args)