diff --git a/pandas/conftest.py b/pandas/conftest.py index 3fdde3261bd68..35affa62ccf68 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -1144,6 +1144,8 @@ def nullable_string_dtype(request): * 'string' * 'arrow_string' """ + from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401 + return request.param diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index faca868873efa..fd47597b2191f 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -24,6 +24,10 @@ from pandas.util._validators import validate_fillna_kwargs from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.common import ( + is_object_dtype, + is_string_dtype, +) from pandas.core.dtypes.dtypes import register_extension_dtype from pandas.core.dtypes.missing import isna @@ -41,6 +45,7 @@ check_array_indexer, validate_indices, ) +from pandas.core.strings.object_array import ObjectStringArrayMixin try: import pyarrow as pa @@ -149,7 +154,12 @@ def __eq__(self, other) -> bool: return False -class ArrowStringArray(OpsMixin, ExtensionArray): +# TODO: Inherit directly from BaseStringArrayMethods. Currently we inherit from +# ObjectStringArrayMixin because we want to have the object-dtype based methods as +# fallback for the ones that pyarrow doesn't yet support + + +class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin): """ Extension array for string data in a ``pyarrow.ChunkedArray``. @@ -676,3 +686,71 @@ def value_counts(self, dropna: bool = True) -> Series: raise NotImplementedError("yo") return Series(counts, index=index).astype("Int64") + + # ------------------------------------------------------------------------ + # String methods interface + + _str_na_value = ArrowStringDtype.na_value + + def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + # TODO: de-duplicate with StringArray method. This method is moreless copy and + # paste. + + from pandas.arrays import ( + BooleanArray, + IntegerArray, + ) + + if dtype is None: + dtype = self.dtype + if na_value is None: + na_value = self.dtype.na_value + + mask = isna(self) + arr = np.asarray(self) + + if is_integer_dtype(dtype) or is_bool_dtype(dtype): + constructor: type[IntegerArray] | type[BooleanArray] + if is_integer_dtype(dtype): + constructor = IntegerArray + else: + constructor = BooleanArray + + na_value_is_na = isna(na_value) + if na_value_is_na: + na_value = 1 + result = lib.map_infer_mask( + arr, + f, + mask.view("uint8"), + convert=False, + na_value=na_value, + # error: Value of type variable "_DTypeScalar" of "dtype" cannot be + # "object" + # error: Argument 1 to "dtype" has incompatible type + # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected + # "Type[object]" + dtype=np.dtype(dtype), # type: ignore[type-var,arg-type] + ) + + if not na_value_is_na: + mask[:] = False + + # error: Argument 1 to "IntegerArray" has incompatible type + # "Union[ExtensionArray, ndarray]"; expected "ndarray" + # error: Argument 1 to "BooleanArray" has incompatible type + # "Union[ExtensionArray, ndarray]"; expected "ndarray" + return constructor(result, mask) # type: ignore[arg-type] + + elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # i.e. StringDtype + result = lib.map_infer_mask( + arr, f, mask.view("uint8"), convert=False, na_value=na_value + ) + return self._from_sequence(result) + else: + # This is when the result type is object. We reach this when + # -> We know the result type is truly object (e.g. .encode returns bytes + # or .findall returns a list). + # -> We don't know the result type. E.g. `.get` can return anything. + return lib.map_infer_mask(arr, f, mask.view("uint8")) diff --git a/pandas/core/strings/__init__.py b/pandas/core/strings/__init__.py index 943686fc85a05..28aba7c9ce0b3 100644 --- a/pandas/core/strings/__init__.py +++ b/pandas/core/strings/__init__.py @@ -25,6 +25,7 @@ # - StringArray # - PandasArray # - Categorical +# - ArrowStringArray from pandas.core.strings.accessor import StringMethods from pandas.core.strings.base import BaseStringArrayMethods diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 85da954ec842e..0b5613e302175 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -154,10 +154,11 @@ class StringMethods(NoNewAttributesMixin): def __init__(self, data): from pandas.core.arrays.string_ import StringDtype + from pandas.core.arrays.string_arrow import ArrowStringDtype self._inferred_dtype = self._validate(data) self._is_categorical = is_categorical_dtype(data.dtype) - self._is_string = isinstance(data.dtype, StringDtype) + self._is_string = isinstance(data.dtype, (StringDtype, ArrowStringDtype)) self._data = data self._index = self._name = None @@ -316,7 +317,7 @@ def cons_row(x): # This is a mess. dtype: Optional[str] if self._is_string and returns_string: - dtype = "string" + dtype = self._orig.dtype else: dtype = None diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 8b84a510c01e6..749f3d0aee8a5 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -91,17 +91,11 @@ def test_setitem_with_scalar_string(dtype): @pytest.mark.parametrize( "input, method", [ - (["a", "b", "c"], operator.methodcaller("capitalize")), (["a", "b", "c"], operator.methodcaller("capitalize")), (["a b", "a bc. de"], operator.methodcaller("capitalize")), ], ) -def test_string_methods(input, method, dtype, request): - if dtype == "arrow_string": - reason = "AttributeError: 'ArrowStringDtype' object has no attribute 'base'" - mark = pytest.mark.xfail(reason=reason) - request.node.add_marker(mark) - +def test_string_methods(input, method, dtype): a = pd.Series(input, dtype=dtype) b = pd.Series(input, dtype="object") result = method(a.str) diff --git a/pandas/tests/strings/test_string_array.py b/pandas/tests/strings/test_string_array.py index b51132caf7573..23c9b14c5a36a 100644 --- a/pandas/tests/strings/test_string_array.py +++ b/pandas/tests/strings/test_string_array.py @@ -11,14 +11,22 @@ ) -def test_string_array(any_string_method): +def test_string_array(nullable_string_dtype, any_string_method, request): method_name, args, kwargs = any_string_method if method_name == "decode": pytest.skip("decode requires bytes.") + if nullable_string_dtype == "arrow_string" and method_name in { + "extract", + "extractall", + }: + reason = "extract/extractall does not yet dispatch to array" + mark = pytest.mark.xfail(reason=reason) + request.node.add_marker(mark) + data = ["a", "bb", np.nan, "ccc"] a = Series(data, dtype=object) - b = Series(data, dtype="string") + b = Series(data, dtype=nullable_string_dtype) expected = getattr(a.str, method_name)(*args, **kwargs) result = getattr(b.str, method_name)(*args, **kwargs) @@ -27,7 +35,7 @@ def test_string_array(any_string_method): if expected.dtype == "object" and lib.is_string_array( expected.dropna().values, ): - assert result.dtype == "string" + assert result.dtype == nullable_string_dtype result = result.astype(object) elif expected.dtype == "object" and lib.is_bool_array( @@ -46,7 +54,7 @@ def test_string_array(any_string_method): elif isinstance(expected, DataFrame): columns = expected.select_dtypes(include="object").columns - assert all(result[columns].dtypes == "string") + assert all(result[columns].dtypes == nullable_string_dtype) result[columns] = result[columns].astype(object) tm.assert_equal(result, expected) @@ -60,8 +68,8 @@ def test_string_array(any_string_method): ("rindex", [2, None]), ], ) -def test_string_array_numeric_integer_array(method, expected): - s = Series(["aba", None], dtype="string") +def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected): + s = Series(["aba", None], dtype=nullable_string_dtype) result = getattr(s.str, method)("a") expected = Series(expected, dtype="Int64") tm.assert_series_equal(result, expected) @@ -73,33 +81,39 @@ def test_string_array_numeric_integer_array(method, expected): ("isdigit", [False, None, True]), ("isalpha", [True, None, False]), ("isalnum", [True, None, True]), - ("isdigit", [False, None, True]), + ("isnumeric", [False, None, True]), ], ) -def test_string_array_boolean_array(method, expected): - s = Series(["a", None, "1"], dtype="string") +def test_string_array_boolean_array(nullable_string_dtype, method, expected): + s = Series(["a", None, "1"], dtype=nullable_string_dtype) result = getattr(s.str, method)() expected = Series(expected, dtype="boolean") tm.assert_series_equal(result, expected) -def test_string_array_extract(): +def test_string_array_extract(nullable_string_dtype, request): # https://github.com/pandas-dev/pandas/issues/30969 # Only expand=False & multiple groups was failing - a = Series(["a1", "b2", "cc"], dtype="string") + + if nullable_string_dtype == "arrow_string": + reason = "extract does not yet dispatch to array" + mark = pytest.mark.xfail(reason=reason) + request.node.add_marker(mark) + + a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype) b = Series(["a1", "b2", "cc"], dtype="object") pat = r"(\w)(\d)" result = a.str.extract(pat, expand=False) expected = b.str.extract(pat, expand=False) - assert all(result.dtypes == "string") + assert all(result.dtypes == nullable_string_dtype) result = result.astype(object) tm.assert_equal(result, expected) -def test_str_get_stringarray_multiple_nans(): - s = Series(pd.array(["a", "ab", pd.NA, "abc"])) +def test_str_get_stringarray_multiple_nans(nullable_string_dtype): + s = Series(pd.array(["a", "ab", pd.NA, "abc"], dtype=nullable_string_dtype)) result = s.str.get(2) - expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"])) + expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"], dtype=nullable_string_dtype)) tm.assert_series_equal(result, expected)