Skip to content

[ArrowStringArray] Use utf8_is_* functions from Apache Arrow if available #41041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 64 additions & 29 deletions asv_bench/benchmarks/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,91 +50,126 @@ def peakmem_cat_frame_construction(self, dtype):


class Methods:
def setup(self):
self.s = Series(tm.makeStringIndex(10 ** 5))
params = ["str", "string", "arrow_string"]
param_names = ["dtype"]

def setup(self, dtype):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
except ImportError:
raise NotImplementedError

def time_center(self):
def time_center(self, dtype):
self.s.str.center(100)

def time_count(self):
def time_count(self, dtype):
self.s.str.count("A")

def time_endswith(self):
def time_endswith(self, dtype):
self.s.str.endswith("A")

def time_extract(self):
def time_extract(self, dtype):
with warnings.catch_warnings(record=True):
self.s.str.extract("(\\w*)A(\\w*)")

def time_findall(self):
def time_findall(self, dtype):
self.s.str.findall("[A-Z]+")

def time_find(self):
def time_find(self, dtype):
self.s.str.find("[A-Z]+")

def time_rfind(self):
def time_rfind(self, dtype):
self.s.str.rfind("[A-Z]+")

def time_get(self):
def time_get(self, dtype):
self.s.str.get(0)

def time_len(self):
def time_len(self, dtype):
self.s.str.len()

def time_join(self):
def time_join(self, dtype):
self.s.str.join(" ")

def time_match(self):
def time_match(self, dtype):
self.s.str.match("A")

def time_normalize(self):
def time_normalize(self, dtype):
self.s.str.normalize("NFC")

def time_pad(self):
def time_pad(self, dtype):
self.s.str.pad(100, side="both")

def time_partition(self):
def time_partition(self, dtype):
self.s.str.partition("A")

def time_rpartition(self):
def time_rpartition(self, dtype):
self.s.str.rpartition("A")

def time_replace(self):
def time_replace(self, dtype):
self.s.str.replace("A", "\x01\x01")

def time_translate(self):
def time_translate(self, dtype):
self.s.str.translate({"A": "\x01\x01"})

def time_slice(self):
def time_slice(self, dtype):
self.s.str.slice(5, 15, 2)

def time_startswith(self):
def time_startswith(self, dtype):
self.s.str.startswith("A")

def time_strip(self):
def time_strip(self, dtype):
self.s.str.strip("A")

def time_rstrip(self):
def time_rstrip(self, dtype):
self.s.str.rstrip("A")

def time_lstrip(self):
def time_lstrip(self, dtype):
self.s.str.lstrip("A")

def time_title(self):
def time_title(self, dtype):
self.s.str.title()

def time_upper(self):
def time_upper(self, dtype):
self.s.str.upper()

def time_lower(self):
def time_lower(self, dtype):
self.s.str.lower()

def time_wrap(self):
def time_wrap(self, dtype):
self.s.str.wrap(10)

def time_zfill(self):
def time_zfill(self, dtype):
self.s.str.zfill(10)

def time_isalnum(self, dtype):
self.s.str.isalnum()

def time_isalpha(self, dtype):
self.s.str.isalpha()

def time_isdecimal(self, dtype):
self.s.str.isdecimal()

def time_isdigit(self, dtype):
self.s.str.isdigit()

def time_islower(self, dtype):
self.s.str.islower()

def time_isnumeric(self, dtype):
self.s.str.isnumeric()

def time_isspace(self, dtype):
self.s.str.isspace()

def time_istitle(self, dtype):
self.s.str.istitle()

def time_isupper(self, dtype):
self.s.str.isupper()


class Repeat:

Expand Down
64 changes: 64 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pandas.core import missing
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.indexers import (
check_array_indexer,
validate_indices,
Expand Down Expand Up @@ -758,6 +759,69 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_isalnum(self):
if hasattr(pc, "utf8_is_alnum"):
result = pc.utf8_is_alnum(self._data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point (not necessarily this PR), it might be worth benchmarking to see if calling pc.string_is_ascii first to then potentially use pc.ascii_is_alnum instead of pc.utf8_is_alnum could be worth it (which would be assuming that testing whether it's all ascii takes much less time than the benefit from using the faster ascii algorithm vs the utf8 one)

return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isalnum()

def _str_isalpha(self):
if hasattr(pc, "utf8_is_alpha"):
result = pc.utf8_is_alpha(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isalpha()

def _str_isdecimal(self):
if hasattr(pc, "utf8_is_decimal"):
result = pc.utf8_is_decimal(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isdecimal()

def _str_isdigit(self):
if hasattr(pc, "utf8_is_digit"):
result = pc.utf8_is_digit(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isdigit()

def _str_islower(self):
if hasattr(pc, "utf8_is_lower"):
result = pc.utf8_is_lower(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_islower()

def _str_isnumeric(self):
if hasattr(pc, "utf8_is_numeric"):
result = pc.utf8_is_numeric(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isnumeric()

def _str_isspace(self):
if hasattr(pc, "utf8_is_space"):
result = pc.utf8_is_space(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isspace()

def _str_istitle(self):
if hasattr(pc, "utf8_is_title"):
result = pc.utf8_is_title(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_istitle()

def _str_isupper(self):
if hasattr(pc, "utf8_is_upper"):
result = pc.utf8_is_upper(self._data)
return BooleanDtype().__from_arrow__(result)
else:
return super()._str_isupper()

def _str_lower(self):
return type(self)(pc.utf8_lower(self._data))

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3002,8 +3002,9 @@ def _result_dtype(arr):
# ideally we just pass `dtype=arr.dtype` unconditionally, but this fails
# when the list of values is empty.
from pandas.core.arrays.string_ import StringDtype
from pandas.core.arrays.string_arrow import ArrowStringDtype

if isinstance(arr.dtype, StringDtype):
if isinstance(arr.dtype, (StringDtype, ArrowStringDtype)):
return arr.dtype.name
else:
return object
Expand Down
17 changes: 2 additions & 15 deletions pandas/tests/strings/test_string_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,11 @@
)


def test_string_array(nullable_string_dtype, any_string_method, request):
def test_string_array(nullable_string_dtype, any_string_method):
method_name, args, kwargs = any_string_method
if method_name == "decode":
pytest.skip("decode requires bytes.")

if nullable_string_dtype == "arrow_string" and method_name in {
"extract",
"extractall",
}:
reason = "extract/extractall does not yet dispatch to array"
mark = pytest.mark.xfail(reason=reason)
request.node.add_marker(mark)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is extract fixed now? (but not in this PR?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the tests no longer fail. there is a change in this PR that "fixes" by special casing ArrowStringArray (like StringArray). extract/extractall will still need to be updated to dispatch to the array. not in this PR. see #41041 (comment)

the change to pandas/core/strings/accessor.py makes ArrowStringArray work like StringArray. If we don't add the change I would need to xfail test_empty_str_methods which totally defeats the purpose of parameterising the tests to get extra test coverage for the is_methods. The alternative is to split test_empty_str_methods and xfail the extract/extractall tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks for the explanation. No need to split off, I was just wondering how the changes in this PR "fixed" it ;)


data = ["a", "bb", np.nan, "ccc"]
a = Series(data, dtype=object)
b = Series(data, dtype=nullable_string_dtype)
Expand Down Expand Up @@ -93,15 +85,10 @@ def test_string_array_boolean_array(nullable_string_dtype, method, expected):
tm.assert_series_equal(result, expected)


def test_string_array_extract(nullable_string_dtype, request):
def test_string_array_extract(nullable_string_dtype):
# https://github.com/pandas-dev/pandas/issues/30969
# Only expand=False & multiple groups was failing

if nullable_string_dtype == "arrow_string":
reason = "extract does not yet dispatch to array"
mark = pytest.mark.xfail(reason=reason)
request.node.add_marker(mark)

a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype)
b = Series(["a1", "b2", "cc"], dtype="object")
pat = r"(\w)(\d)"
Expand Down
Loading