Skip to content

Commit 0740197

Browse files
simonjayhawkins authored and yeshsurya committed
[ArrowStringArray] Use utf8_is_* functions from Apache Arrow if available (pandas-dev#41041)
1 parent 87f3a7d commit 0740197

File tree

5 files changed

+189
-71
lines changed

5 files changed

+189
-71
lines changed

asv_bench/benchmarks/strings.py

+64 -29
Original file line number | Diff line number | Diff line change
@@ -50,91 +50,126 @@ def peakmem_cat_frame_construction(self, dtype):
5050

5151

5252
class Methods:
53-
def setup(self):
54-
self.s = Series(tm.makeStringIndex(10 ** 5))
53+
params = ["str", "string", "arrow_string"]
54+
param_names = ["dtype"]
55+
56+
def setup(self, dtype):
57+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
58+
59+
try:
60+
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
61+
except ImportError:
62+
raise NotImplementedError
5563

56-
def time_center(self):
64+
def time_center(self, dtype):
5765
self.s.str.center(100)
5866

59-
def time_count(self):
67+
def time_count(self, dtype):
6068
self.s.str.count("A")
6169

62-
def time_endswith(self):
70+
def time_endswith(self, dtype):
6371
self.s.str.endswith("A")
6472

65-
def time_extract(self):
73+
def time_extract(self, dtype):
6674
with warnings.catch_warnings(record=True):
6775
self.s.str.extract("(\\w*)A(\\w*)")
6876

69-
def time_findall(self):
77+
def time_findall(self, dtype):
7078
self.s.str.findall("[A-Z]+")
7179

72-
def time_find(self):
80+
def time_find(self, dtype):
7381
self.s.str.find("[A-Z]+")
7482

75-
def time_rfind(self):
83+
def time_rfind(self, dtype):
7684
self.s.str.rfind("[A-Z]+")
7785

78-
def time_get(self):
86+
def time_get(self, dtype):
7987
self.s.str.get(0)
8088

81-
def time_len(self):
89+
def time_len(self, dtype):
8290
self.s.str.len()
8391

84-
def time_join(self):
92+
def time_join(self, dtype):
8593
self.s.str.join(" ")
8694

87-
def time_match(self):
95+
def time_match(self, dtype):
8896
self.s.str.match("A")
8997

90-
def time_normalize(self):
98+
def time_normalize(self, dtype):
9199
self.s.str.normalize("NFC")
92100

93-
def time_pad(self):
101+
def time_pad(self, dtype):
94102
self.s.str.pad(100, side="both")
95103

96-
def time_partition(self):
104+
def time_partition(self, dtype):
97105
self.s.str.partition("A")
98106

99-
def time_rpartition(self):
107+
def time_rpartition(self, dtype):
100108
self.s.str.rpartition("A")
101109

102-
def time_replace(self):
110+
def time_replace(self, dtype):
103111
self.s.str.replace("A", "\x01\x01")
104112

105-
def time_translate(self):
113+
def time_translate(self, dtype):
106114
self.s.str.translate({"A": "\x01\x01"})
107115

108-
def time_slice(self):
116+
def time_slice(self, dtype):
109117
self.s.str.slice(5, 15, 2)
110118

111-
def time_startswith(self):
119+
def time_startswith(self, dtype):
112120
self.s.str.startswith("A")
113121

114-
def time_strip(self):
122+
def time_strip(self, dtype):
115123
self.s.str.strip("A")
116124

117-
def time_rstrip(self):
125+
def time_rstrip(self, dtype):
118126
self.s.str.rstrip("A")
119127

120-
def time_lstrip(self):
128+
def time_lstrip(self, dtype):
121129
self.s.str.lstrip("A")
122130

123-
def time_title(self):
131+
def time_title(self, dtype):
124132
self.s.str.title()
125133

126-
def time_upper(self):
134+
def time_upper(self, dtype):
127135
self.s.str.upper()
128136

129-
def time_lower(self):
137+
def time_lower(self, dtype):
130138
self.s.str.lower()
131139

132-
def time_wrap(self):
140+
def time_wrap(self, dtype):
133141
self.s.str.wrap(10)
134142

135-
def time_zfill(self):
143+
def time_zfill(self, dtype):
136144
self.s.str.zfill(10)
137145

146+
def time_isalnum(self, dtype):
147+
self.s.str.isalnum()
148+
149+
def time_isalpha(self, dtype):
150+
self.s.str.isalpha()
151+
152+
def time_isdecimal(self, dtype):
153+
self.s.str.isdecimal()
154+
155+
def time_isdigit(self, dtype):
156+
self.s.str.isdigit()
157+
158+
def time_islower(self, dtype):
159+
self.s.str.islower()
160+
161+
def time_isnumeric(self, dtype):
162+
self.s.str.isnumeric()
163+
164+
def time_isspace(self, dtype):
165+
self.s.str.isspace()
166+
167+
def time_istitle(self, dtype):
168+
self.s.str.istitle()
169+
170+
def time_isupper(self, dtype):
171+
self.s.str.isupper()
172+
138173

139174
class Repeat:
140175

pandas/core/arrays/string_arrow.py

+64
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
from pandas.core import missing
4040
from pandas.core.arraylike import OpsMixin
4141
from pandas.core.arrays.base import ExtensionArray
42+
from pandas.core.arrays.boolean import BooleanDtype
4243
from pandas.core.indexers import (
4344
check_array_indexer,
4445
validate_indices,
@@ -758,6 +759,69 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
758759
# -> We don't know the result type. E.g. `.get` can return anything.
759760
return lib.map_infer_mask(arr, f, mask.view("uint8"))
760761

762+
def _str_isalnum(self):
763+
if hasattr(pc, "utf8_is_alnum"):
764+
result = pc.utf8_is_alnum(self._data)
765+
return BooleanDtype().__from_arrow__(result)
766+
else:
767+
return super()._str_isalnum()
768+
769+
def _str_isalpha(self):
770+
if hasattr(pc, "utf8_is_alpha"):
771+
result = pc.utf8_is_alpha(self._data)
772+
return BooleanDtype().__from_arrow__(result)
773+
else:
774+
return super()._str_isalpha()
775+
776+
def _str_isdecimal(self):
777+
if hasattr(pc, "utf8_is_decimal"):
778+
result = pc.utf8_is_decimal(self._data)
779+
return BooleanDtype().__from_arrow__(result)
780+
else:
781+
return super()._str_isdecimal()
782+
783+
def _str_isdigit(self):
784+
if hasattr(pc, "utf8_is_digit"):
785+
result = pc.utf8_is_digit(self._data)
786+
return BooleanDtype().__from_arrow__(result)
787+
else:
788+
return super()._str_isdigit()
789+
790+
def _str_islower(self):
791+
if hasattr(pc, "utf8_is_lower"):
792+
result = pc.utf8_is_lower(self._data)
793+
return BooleanDtype().__from_arrow__(result)
794+
else:
795+
return super()._str_islower()
796+
797+
def _str_isnumeric(self):
798+
if hasattr(pc, "utf8_is_numeric"):
799+
result = pc.utf8_is_numeric(self._data)
800+
return BooleanDtype().__from_arrow__(result)
801+
else:
802+
return super()._str_isnumeric()
803+
804+
def _str_isspace(self):
805+
if hasattr(pc, "utf8_is_space"):
806+
result = pc.utf8_is_space(self._data)
807+
return BooleanDtype().__from_arrow__(result)
808+
else:
809+
return super()._str_isspace()
810+
811+
def _str_istitle(self):
812+
if hasattr(pc, "utf8_is_title"):
813+
result = pc.utf8_is_title(self._data)
814+
return BooleanDtype().__from_arrow__(result)
815+
else:
816+
return super()._str_istitle()
817+
818+
def _str_isupper(self):
819+
if hasattr(pc, "utf8_is_upper"):
820+
result = pc.utf8_is_upper(self._data)
821+
return BooleanDtype().__from_arrow__(result)
822+
else:
823+
return super()._str_isupper()
824+
761825
def _str_lower(self):
762826
return type(self)(pc.utf8_lower(self._data))
763827

pandas/core/strings/accessor.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -3002,8 +3002,9 @@ def _result_dtype(arr):
30023002
# ideally we just pass `dtype=arr.dtype` unconditionally, but this fails
30033003
# when the list of values is empty.
30043004
from pandas.core.arrays.string_ import StringDtype
3005+
from pandas.core.arrays.string_arrow import ArrowStringDtype
30053006

3006-
if isinstance(arr.dtype, StringDtype):
3007+
if isinstance(arr.dtype, (StringDtype, ArrowStringDtype)):
30073008
return arr.dtype.name
30083009
else:
30093010
return object

pandas/tests/strings/test_string_array.py

+2 -15
Original file line number | Diff line number | Diff line change
@@ -13,19 +13,11 @@
1313
)
1414

1515

16-
def test_string_array(nullable_string_dtype, any_string_method, request):
16+
def test_string_array(nullable_string_dtype, any_string_method):
1717
method_name, args, kwargs = any_string_method
1818
if method_name == "decode":
1919
pytest.skip("decode requires bytes.")
2020

21-
if nullable_string_dtype == "arrow_string" and method_name in {
22-
"extract",
23-
"extractall",
24-
}:
25-
reason = "extract/extractall does not yet dispatch to array"
26-
mark = pytest.mark.xfail(reason=reason)
27-
request.node.add_marker(mark)
28-
2921
data = ["a", "bb", np.nan, "ccc"]
3022
a = Series(data, dtype=object)
3123
b = Series(data, dtype=nullable_string_dtype)
@@ -93,15 +85,10 @@ def test_string_array_boolean_array(nullable_string_dtype, method, expected):
9385
tm.assert_series_equal(result, expected)
9486

9587

96-
def test_string_array_extract(nullable_string_dtype, request):
88+
def test_string_array_extract(nullable_string_dtype):
9789
# https://github.com/pandas-dev/pandas/issues/30969
9890
# Only expand=False & multiple groups was failing
9991

100-
if nullable_string_dtype == "arrow_string":
101-
reason = "extract does not yet dispatch to array"
102-
mark = pytest.mark.xfail(reason=reason)
103-
request.node.add_marker(mark)
104-
10592
a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype)
10693
b = Series(["a1", "b2", "cc"], dtype="object")
10794
pat = r"(\w)(\d)"

0 commit comments

Comments
 (0)