Skip to content

Commit bbe1fa2

Browse files
simonjayhawkinsTLouf
authored andcommitted
[ArrowStringArray] PERF: str.extract object fallback (pandas-dev#41490)
1 parent ef8df3d commit bbe1fa2

File tree

3 files changed

+55
-52
lines changed

3 files changed

+55
-52
lines changed

asv_bench/benchmarks/strings.py

+42-37
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
1111
from .pandas_vb_common import tm
1212

1313

14+
class Dtypes:
15+
params = ["str", "string", "arrow_string"]
16+
param_names = ["dtype"]
17+
18+
def setup(self, dtype):
19+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
20+
21+
try:
22+
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
23+
except ImportError:
24+
raise NotImplementedError
25+
26+
1427
class Construction:
1528

1629
params = ["str", "string"]
@@ -49,18 +62,7 @@ def peakmem_cat_frame_construction(self, dtype):
4962
DataFrame(self.frame_cat_arr, dtype=dtype)
5063

5164

52-
class Methods:
53-
params = ["str", "string", "arrow_string"]
54-
param_names = ["dtype"]
55-
56-
def setup(self, dtype):
57-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
58-
59-
try:
60-
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
61-
except ImportError:
62-
raise NotImplementedError
63-
65+
class Methods(Dtypes):
6466
def time_center(self, dtype):
6567
self.s.str.center(100)
6668

@@ -211,35 +213,26 @@ def time_cat(self, other_cols, sep, na_rep, na_frac):
211213
self.s.str.cat(others=self.others, sep=sep, na_rep=na_rep)
212214

213215

214-
class Contains:
216+
class Contains(Dtypes):
215217

216-
params = (["str", "string", "arrow_string"], [True, False])
218+
params = (Dtypes.params, [True, False])
217219
param_names = ["dtype", "regex"]
218220

219221
def setup(self, dtype, regex):
220-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
221-
222-
try:
223-
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
224-
except ImportError:
225-
raise NotImplementedError
222+
super().setup(dtype)
226223

227224
def time_contains(self, dtype, regex):
228225
self.s.str.contains("A", regex=regex)
229226

230227

231-
class Split:
228+
class Split(Dtypes):
232229

233-
params = (["str", "string", "arrow_string"], [True, False])
230+
params = (Dtypes.params, [True, False])
234231
param_names = ["dtype", "expand"]
235232

236233
def setup(self, dtype, expand):
237-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
238-
239-
try:
240-
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("--")
241-
except ImportError:
242-
raise NotImplementedError
234+
super().setup(dtype)
235+
self.s = self.s.str.join("--")
243236

244237
def time_split(self, dtype, expand):
245238
self.s.str.split("--", expand=expand)
@@ -248,17 +241,23 @@ def time_rsplit(self, dtype, expand):
248241
self.s.str.rsplit("--", expand=expand)
249242

250243

251-
class Dummies:
252-
params = ["str", "string", "arrow_string"]
253-
param_names = ["dtype"]
244+
class Extract(Dtypes):
254245

255-
def setup(self, dtype):
256-
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
246+
params = (Dtypes.params, [True, False])
247+
param_names = ["dtype", "expand"]
257248

258-
try:
259-
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("|")
260-
except ImportError:
261-
raise NotImplementedError
249+
def setup(self, dtype, expand):
250+
super().setup(dtype)
251+
252+
def time_extract_single_group(self, dtype, expand):
253+
with warnings.catch_warnings(record=True):
254+
self.s.str.extract("(\\w*)A", expand=expand)
255+
256+
257+
class Dummies(Dtypes):
258+
def setup(self, dtype):
259+
super().setup(dtype)
260+
self.s = self.s.str.join("|")
262261

263262
def time_get_dummies(self, dtype):
264263
self.s.str.get_dummies("|")
@@ -279,3 +278,9 @@ def setup(self):
279278
def time_vector_slice(self):
280279
# GH 2602
281280
self.s.str[:5]
281+
282+
283+
class Iter(Dtypes):
284+
def time_iter(self, dtype):
285+
for i in self.s:
286+
pass

pandas/core/strings/accessor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3101,7 +3101,7 @@ def _str_extract_noexpand(arr, pat, flags=0):
31013101
groups_or_na = _groups_or_na_fun(regex)
31023102
result_dtype = _result_dtype(arr)
31033103

3104-
result = np.array([groups_or_na(val)[0] for val in arr], dtype=object)
3104+
result = np.array([groups_or_na(val)[0] for val in np.asarray(arr)], dtype=object)
31053105
# not dispatching, so we have to reconstruct here.
31063106
result = pd_array(result, dtype=result_dtype)
31073107
return result
@@ -3136,7 +3136,7 @@ def _str_extract_frame(arr, pat, flags=0):
31363136
else:
31373137
result_index = None
31383138
return DataFrame(
3139-
[groups_or_na(val) for val in arr],
3139+
[groups_or_na(val) for val in np.asarray(arr)],
31403140
columns=columns,
31413141
index=result_index,
31423142
dtype=result_dtype,

pandas/tests/strings/test_extract.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,23 @@ def test_extract_expand_kwarg(any_string_dtype):
3838

3939

4040
def test_extract_expand_False_mixed_object():
41-
mixed = Series(
42-
[
43-
"aBAD_BAD",
44-
np.nan,
45-
"BAD_b_BAD",
46-
True,
47-
datetime.today(),
48-
"foo",
49-
None,
50-
1,
51-
2.0,
52-
]
41+
ser = Series(
42+
["aBAD_BAD", np.nan, "BAD_b_BAD", True, datetime.today(), "foo", None, 1, 2.0]
5343
)
5444

55-
result = Series(mixed).str.extract(".*(BAD[_]+).*(BAD)", expand=False)
45+
# two groups
46+
result = ser.str.extract(".*(BAD[_]+).*(BAD)", expand=False)
5647
er = [np.nan, np.nan] # empty row
5748
expected = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er])
5849
tm.assert_frame_equal(result, expected)
5950

51+
# single group
52+
result = ser.str.extract(".*(BAD[_]+).*BAD", expand=False)
53+
expected = Series(
54+
["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]
55+
)
56+
tm.assert_series_equal(result, expected)
57+
6058

6159
def test_extract_expand_index_raises():
6260
# GH9980

0 commit comments

Comments
 (0)