-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: [ArrowStringArray] Enable the string methods for the arrow-backed StringArray #40708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
9b6c054
c0fedcd
e603321
37035b4
c823953
76292ec
50db876
374924b
23b40e3
4b86f67
7bd82f9
aaf54ca
b1cf83d
36d0034
f479819
d1c8a3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,10 @@ | |
from pandas.util._validators import validate_fillna_kwargs | ||
|
||
from pandas.core.dtypes.base import ExtensionDtype | ||
from pandas.core.dtypes.common import ( | ||
is_object_dtype, | ||
is_string_dtype, | ||
) | ||
from pandas.core.dtypes.dtypes import register_extension_dtype | ||
from pandas.core.dtypes.missing import isna | ||
|
||
|
@@ -43,6 +47,7 @@ | |
check_array_indexer, | ||
validate_indices, | ||
) | ||
from pandas.core.strings.object_array import ObjectStringArrayMixin | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the goal is to inherit from BaseStringArrayMethods (I recall this being mentioned somewhere). For now, use ObjectStringArrayMixin, similar to fletcher (xhochy/fletcher#196). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is because you want to have the object-dtype based methods as fallback for the ones that pyarrow doesn't yet support, I suppose? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep. This PR just gets the string methods working for the existing tests we have for StringArray, so it just converts to object and does not use native pyarrow functions yet. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a small comment about that (e.g. below where the class is created) about why |
||
|
||
try: | ||
import pyarrow as pa | ||
|
@@ -153,7 +158,7 @@ def __eq__(self, other) -> bool: | |
return False | ||
|
||
|
||
class ArrowStringArray(OpsMixin, ExtensionArray): | ||
class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin): | ||
""" | ||
Extension array for string data in a ``pyarrow.ChunkedArray``. | ||
|
||
|
@@ -680,3 +685,64 @@ def value_counts(self, dropna: bool = True) -> Series: | |
raise NotImplementedError("yo") | ||
|
||
return Series(counts, index=index).astype("Int64") | ||
|
||
# ------------------------------------------------------------------------ | ||
# String methods interface | ||
|
||
_str_na_value = ArrowStringDtype.na_value | ||
|
||
def _str_map(self, f, na_value=None, dtype: Dtype | None = None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. More or less cut and paste from StringArray. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This could be shared? (e.g. move it to a common helper function or mixin?) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indeed. I think the de-duplication is better as an immediate follow-up, to keep changes to existing code paths (i.e. StringArray) in a separate PR and keep this one scoped to just additions. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Either is fine for me. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Or have a common base class (see #35169 (comment)), so this is better as a follow-on to allow for more discussion. |
||
from pandas.arrays import ( | ||
BooleanArray, | ||
IntegerArray, | ||
) | ||
|
||
if dtype is None: | ||
dtype = self.dtype | ||
if na_value is None: | ||
na_value = self.dtype.na_value | ||
|
||
mask = isna(self) | ||
arr = np.asarray(self) | ||
|
||
if is_integer_dtype(dtype) or is_bool_dtype(dtype): | ||
constructor: type[IntegerArray] | type[BooleanArray] | ||
if is_integer_dtype(dtype): | ||
constructor = IntegerArray | ||
else: | ||
constructor = BooleanArray | ||
|
||
na_value_is_na = isna(na_value) | ||
if na_value_is_na: | ||
na_value = 1 | ||
result = lib.map_infer_mask( | ||
arr, | ||
f, | ||
mask.view("uint8"), | ||
convert=False, | ||
na_value=na_value, | ||
# error: Value of type variable "_DTypeScalar" of "dtype" cannot be | ||
# "object" | ||
# error: Argument 1 to "dtype" has incompatible type | ||
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected | ||
# "Type[object]" | ||
dtype=np.dtype(dtype), # type: ignore[type-var,arg-type] | ||
) | ||
|
||
if not na_value_is_na: | ||
mask[:] = False | ||
|
||
return constructor(result, mask) | ||
|
||
elif is_string_dtype(dtype) and not is_object_dtype(dtype): | ||
# i.e. StringDtype | ||
result = lib.map_infer_mask( | ||
arr, f, mask.view("uint8"), convert=False, na_value=na_value | ||
) | ||
return self._from_sequence(result) | ||
else: | ||
# This is when the result type is object. We reach this when | ||
# -> We know the result type is truly object (e.g. .encode returns bytes | ||
# or .findall returns a list). | ||
# -> We don't know the result type. E.g. `.get` can return anything. | ||
return lib.map_infer_mask(arr, f, mask.view("uint8")) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,10 +154,11 @@ class StringMethods(NoNewAttributesMixin): | |
|
||
def __init__(self, data): | ||
from pandas.core.arrays.string_ import StringDtype | ||
from pandas.core.arrays.string_arrow import ArrowStringDtype | ||
|
||
self._inferred_dtype = self._validate(data) | ||
self._is_categorical = is_categorical_dtype(data.dtype) | ||
self._is_string = isinstance(data.dtype, StringDtype) | ||
self._is_string = isinstance(data.dtype, (StringDtype, ArrowStringDtype)) | ||
self._data = data | ||
|
||
self._index = self._name = None | ||
|
@@ -316,7 +317,7 @@ def cons_row(x): | |
# This is a mess. | ||
dtype: Optional[str] | ||
if self._is_string and returns_string: | ||
dtype = "string" | ||
dtype = self._orig.dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Special case here for partition/rpartition/split, where the array method returns an object array. |
||
else: | ||
dtype = None | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,14 +11,22 @@ | |
) | ||
|
||
|
||
def test_string_array(any_string_method): | ||
def test_string_array(nullable_string_dtype, any_string_method, request): | ||
method_name, args, kwargs = any_string_method | ||
if method_name == "decode": | ||
pytest.skip("decode requires bytes.") | ||
|
||
if nullable_string_dtype == "arrow_string" and method_name in { | ||
"extract", | ||
"extractall", | ||
}: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could either special-case this, refactor to dispatch to the array, or leave it as a follow-up and xfail for now. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the xfail is fine for now. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed. Keeps the scope of this PR limited. |
||
reason = "extract/extractall does not yet dispatch to array" | ||
mark = pytest.mark.xfail(reason=reason) | ||
request.node.add_marker(mark) | ||
|
||
data = ["a", "bb", np.nan, "ccc"] | ||
a = Series(data, dtype=object) | ||
b = Series(data, dtype="string") | ||
b = Series(data, dtype=nullable_string_dtype) | ||
|
||
expected = getattr(a.str, method_name)(*args, **kwargs) | ||
result = getattr(b.str, method_name)(*args, **kwargs) | ||
|
@@ -27,7 +35,7 @@ def test_string_array(any_string_method): | |
if expected.dtype == "object" and lib.is_string_array( | ||
expected.dropna().values, | ||
): | ||
assert result.dtype == "string" | ||
assert result.dtype == nullable_string_dtype | ||
result = result.astype(object) | ||
|
||
elif expected.dtype == "object" and lib.is_bool_array( | ||
|
@@ -46,7 +54,7 @@ def test_string_array(any_string_method): | |
|
||
elif isinstance(expected, DataFrame): | ||
columns = expected.select_dtypes(include="object").columns | ||
assert all(result[columns].dtypes == "string") | ||
assert all(result[columns].dtypes == nullable_string_dtype) | ||
result[columns] = result[columns].astype(object) | ||
tm.assert_equal(result, expected) | ||
|
||
|
@@ -60,8 +68,8 @@ def test_string_array(any_string_method): | |
("rindex", [2, None]), | ||
], | ||
) | ||
def test_string_array_numeric_integer_array(method, expected): | ||
s = Series(["aba", None], dtype="string") | ||
def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected): | ||
s = Series(["aba", None], dtype=nullable_string_dtype) | ||
result = getattr(s.str, method)("a") | ||
expected = Series(expected, dtype="Int64") | ||
tm.assert_series_equal(result, expected) | ||
|
@@ -73,33 +81,38 @@ def test_string_array_numeric_integer_array(method, expected): | |
("isdigit", [False, None, True]), | ||
("isalpha", [True, None, False]), | ||
("isalnum", [True, None, True]), | ||
("isdigit", [False, None, True]), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Another duplicate. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we can add "isnumeric" instead, which doesn't seem to be tested. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure. "isnumeric" and other "is" functions are tested to some degree by the _any_string_method fixture used in test_string_array. |
||
], | ||
) | ||
def test_string_array_boolean_array(method, expected): | ||
s = Series(["a", None, "1"], dtype="string") | ||
def test_string_array_boolean_array(nullable_string_dtype, method, expected): | ||
s = Series(["a", None, "1"], dtype=nullable_string_dtype) | ||
result = getattr(s.str, method)() | ||
expected = Series(expected, dtype="boolean") | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_string_array_extract(): | ||
def test_string_array_extract(nullable_string_dtype, request): | ||
# https://github.com/pandas-dev/pandas/issues/30969 | ||
# Only expand=False & multiple groups was failing | ||
a = Series(["a1", "b2", "cc"], dtype="string") | ||
|
||
if nullable_string_dtype == "arrow_string": | ||
reason = "extract does not yet dispatch to array" | ||
mark = pytest.mark.xfail(reason=reason) | ||
request.node.add_marker(mark) | ||
|
||
a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype) | ||
b = Series(["a1", "b2", "cc"], dtype="object") | ||
pat = r"(\w)(\d)" | ||
|
||
result = a.str.extract(pat, expand=False) | ||
expected = b.str.extract(pat, expand=False) | ||
assert all(result.dtypes == "string") | ||
assert all(result.dtypes == nullable_string_dtype) | ||
|
||
result = result.astype(object) | ||
tm.assert_equal(result, expected) | ||
|
||
|
||
def test_str_get_stringarray_multiple_nans(): | ||
s = Series(pd.array(["a", "ab", pd.NA, "abc"])) | ||
def test_str_get_stringarray_multiple_nans(nullable_string_dtype): | ||
s = Series(pd.array(["a", "ab", pd.NA, "abc"], dtype=nullable_string_dtype)) | ||
result = s.str.get(2) | ||
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"])) | ||
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"], dtype=nullable_string_dtype)) | ||
tm.assert_series_equal(result, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR as draft. This will be resolved in a follow-up to #40679 (see #40679 (comment)):
the inference is needed so that the str accessor works.