Skip to content

ENH: [ArrowStringArray] Enable the string methods for the arrow-backed StringArray #40708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,8 @@ def nullable_string_dtype(request):
* 'string'
* 'arrow_string'
"""
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

return request.param


Expand Down
72 changes: 71 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.core.dtypes.missing import isna

Expand All @@ -41,6 +45,7 @@
check_array_indexer,
validate_indices,
)
from pandas.core.strings.object_array import ObjectStringArrayMixin
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the goal is to inherit from BaseStringArrayMethods. (I recall this being mentioned somewhere). For now use ObjectStringArrayMixin similar to fletcher xhochy/fletcher#196

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because you want to have the object-dtype based methods as fallback for the ones that pyarrow doesn't yet support, I suppose?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. This PR just gets the string methods working for the existing tests we have for StringArray, so it just converts to object and does not use native pyarrow functions yet.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a small comment (e.g. below where the class is created) explaining why ObjectStringArrayMixin is mixed in?


try:
import pyarrow as pa
Expand Down Expand Up @@ -149,7 +154,7 @@ def __eq__(self, other) -> bool:
return False


class ArrowStringArray(OpsMixin, ExtensionArray):
class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down Expand Up @@ -676,3 +681,68 @@ def value_counts(self, dropna: bool = True) -> Series:
raise NotImplementedError("yo")

return Series(counts, index=index).astype("Int64")

# ------------------------------------------------------------------------
# String methods interface

_str_na_value = ArrowStringDtype.na_value

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More or less cut and paste from StringArray.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be shared? (e.g. move it to a common helper function or mixin?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed. I think the de-duplication is better as an immediate follow-up to keep changes to existing code paths (i.e. StringArray) in a separate PR and keep this one scoped to just additions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either is fine for me

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. move it to a common helper function or mixin

or have a common base class. #35169 (comment)

so better as a follow-on to allow for more discussion

from pandas.arrays import (
BooleanArray,
IntegerArray,
)

if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
constructor: type[IntegerArray] | type[BooleanArray]
if is_integer_dtype(dtype):
constructor = IntegerArray
else:
constructor = BooleanArray

na_value_is_na = isna(na_value)
if na_value_is_na:
na_value = 1
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
# error: Value of type variable "_DTypeScalar" of "dtype" cannot be
# "object"
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(dtype), # type: ignore[type-var,arg-type]
)

if not na_value_is_na:
mask[:] = False

# error: Argument 1 to "IntegerArray" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
# error: Argument 1 to "BooleanArray" has incompatible type
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
return constructor(result, mask) # type: ignore[arg-type]

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
return self._from_sequence(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
1 change: 1 addition & 0 deletions pandas/core/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# - StringArray
# - PandasArray
# - Categorical
# - ArrowStringArray

from pandas.core.strings.accessor import StringMethods
from pandas.core.strings.base import BaseStringArrayMethods
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ class StringMethods(NoNewAttributesMixin):

def __init__(self, data):
from pandas.core.arrays.string_ import StringDtype
from pandas.core.arrays.string_arrow import ArrowStringDtype

self._inferred_dtype = self._validate(data)
self._is_categorical = is_categorical_dtype(data.dtype)
self._is_string = isinstance(data.dtype, StringDtype)
self._is_string = isinstance(data.dtype, (StringDtype, ArrowStringDtype))
self._data = data

self._index = self._name = None
Expand Down Expand Up @@ -316,7 +317,7 @@ def cons_row(x):
# This is a mess.
dtype: Optional[str]
if self._is_string and returns_string:
dtype = "string"
dtype = self._orig.dtype
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

special case here for partition/rpartition/split where the array method returns an object array

else:
dtype = None

Expand Down
8 changes: 1 addition & 7 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,11 @@ def test_setitem_with_scalar_string(dtype):
@pytest.mark.parametrize(
"input, method",
[
(["a", "b", "c"], operator.methodcaller("capitalize")),
(["a", "b", "c"], operator.methodcaller("capitalize")),
(["a b", "a bc. de"], operator.methodcaller("capitalize")),
],
)
def test_string_methods(input, method, dtype, request):
if dtype == "arrow_string":
reason = "AttributeError: 'ArrowStringDtype' object has no attribute 'base'"
mark = pytest.mark.xfail(reason=reason)
request.node.add_marker(mark)

def test_string_methods(input, method, dtype):
a = pd.Series(input, dtype=dtype)
b = pd.Series(input, dtype="object")
result = method(a.str)
Expand Down
44 changes: 29 additions & 15 deletions pandas/tests/strings/test_string_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
)


def test_string_array(any_string_method):
def test_string_array(nullable_string_dtype, any_string_method, request):
method_name, args, kwargs = any_string_method
if method_name == "decode":
pytest.skip("decode requires bytes.")

if nullable_string_dtype == "arrow_string" and method_name in {
"extract",
"extractall",
}:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could either special-case this, refactor to dispatch to the array, or leave it as a follow-up and xfail for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the xfail is fine for now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. keeps the scope of this PR limited.

reason = "extract/extractall does not yet dispatch to array"
mark = pytest.mark.xfail(reason=reason)
request.node.add_marker(mark)

data = ["a", "bb", np.nan, "ccc"]
a = Series(data, dtype=object)
b = Series(data, dtype="string")
b = Series(data, dtype=nullable_string_dtype)

expected = getattr(a.str, method_name)(*args, **kwargs)
result = getattr(b.str, method_name)(*args, **kwargs)
Expand All @@ -27,7 +35,7 @@ def test_string_array(any_string_method):
if expected.dtype == "object" and lib.is_string_array(
expected.dropna().values,
):
assert result.dtype == "string"
assert result.dtype == nullable_string_dtype
result = result.astype(object)

elif expected.dtype == "object" and lib.is_bool_array(
Expand All @@ -46,7 +54,7 @@ def test_string_array(any_string_method):

elif isinstance(expected, DataFrame):
columns = expected.select_dtypes(include="object").columns
assert all(result[columns].dtypes == "string")
assert all(result[columns].dtypes == nullable_string_dtype)
result[columns] = result[columns].astype(object)
tm.assert_equal(result, expected)

Expand All @@ -60,8 +68,8 @@ def test_string_array(any_string_method):
("rindex", [2, None]),
],
)
def test_string_array_numeric_integer_array(method, expected):
s = Series(["aba", None], dtype="string")
def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected):
s = Series(["aba", None], dtype=nullable_string_dtype)
result = getattr(s.str, method)("a")
expected = Series(expected, dtype="Int64")
tm.assert_series_equal(result, expected)
Expand All @@ -73,33 +81,39 @@ def test_string_array_numeric_integer_array(method, expected):
("isdigit", [False, None, True]),
("isalpha", [True, None, False]),
("isalnum", [True, None, True]),
("isdigit", [False, None, True]),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another dup

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe can add "isnumeric" instead, which doesn't seem to be tested

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. "isnumeric" and other "is" functions are tested to some degree by _any_string_method fixture used in test_string_array

("isnumeric", [False, None, True]),
],
)
def test_string_array_boolean_array(method, expected):
s = Series(["a", None, "1"], dtype="string")
def test_string_array_boolean_array(nullable_string_dtype, method, expected):
s = Series(["a", None, "1"], dtype=nullable_string_dtype)
result = getattr(s.str, method)()
expected = Series(expected, dtype="boolean")
tm.assert_series_equal(result, expected)


def test_string_array_extract():
def test_string_array_extract(nullable_string_dtype, request):
# https://github.com/pandas-dev/pandas/issues/30969
# Only expand=False & multiple groups was failing
a = Series(["a1", "b2", "cc"], dtype="string")

if nullable_string_dtype == "arrow_string":
reason = "extract does not yet dispatch to array"
mark = pytest.mark.xfail(reason=reason)
request.node.add_marker(mark)

a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype)
b = Series(["a1", "b2", "cc"], dtype="object")
pat = r"(\w)(\d)"

result = a.str.extract(pat, expand=False)
expected = b.str.extract(pat, expand=False)
assert all(result.dtypes == "string")
assert all(result.dtypes == nullable_string_dtype)

result = result.astype(object)
tm.assert_equal(result, expected)


def test_str_get_stringarray_multiple_nans():
s = Series(pd.array(["a", "ab", pd.NA, "abc"]))
def test_str_get_stringarray_multiple_nans(nullable_string_dtype):
s = Series(pd.array(["a", "ab", pd.NA, "abc"], dtype=nullable_string_dtype))
result = s.str.get(2)
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"]))
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"], dtype=nullable_string_dtype))
tm.assert_series_equal(result, expected)