Skip to content

[ArrowStringArray] PERF: str.extract object fallback #41490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 42 additions & 37 deletions asv_bench/benchmarks/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
from .pandas_vb_common import tm


class Dtypes:
params = ["str", "string", "arrow_string"]
param_names = ["dtype"]

def setup(self, dtype):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
except ImportError:
raise NotImplementedError


class Construction:

params = ["str", "string"]
Expand Down Expand Up @@ -49,18 +62,7 @@ def peakmem_cat_frame_construction(self, dtype):
DataFrame(self.frame_cat_arr, dtype=dtype)


class Methods:
params = ["str", "string", "arrow_string"]
param_names = ["dtype"]

def setup(self, dtype):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
except ImportError:
raise NotImplementedError

class Methods(Dtypes):
def time_center(self, dtype):
self.s.str.center(100)

Expand Down Expand Up @@ -211,35 +213,26 @@ def time_cat(self, other_cols, sep, na_rep, na_frac):
self.s.str.cat(others=self.others, sep=sep, na_rep=na_rep)


class Contains:
class Contains(Dtypes):

params = (["str", "string", "arrow_string"], [True, False])
params = (Dtypes.params, [True, False])
param_names = ["dtype", "regex"]

def setup(self, dtype, regex):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
except ImportError:
raise NotImplementedError
super().setup(dtype)

def time_contains(self, dtype, regex):
self.s.str.contains("A", regex=regex)


class Split:
class Split(Dtypes):

params = (["str", "string", "arrow_string"], [True, False])
params = (Dtypes.params, [True, False])
param_names = ["dtype", "expand"]

def setup(self, dtype, expand):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("--")
except ImportError:
raise NotImplementedError
super().setup(dtype)
self.s = self.s.str.join("--")

def time_split(self, dtype, expand):
self.s.str.split("--", expand=expand)
Expand All @@ -248,17 +241,23 @@ def time_rsplit(self, dtype, expand):
self.s.str.rsplit("--", expand=expand)


class Dummies:
params = ["str", "string", "arrow_string"]
param_names = ["dtype"]
class Extract(Dtypes):

def setup(self, dtype):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
params = (Dtypes.params, [True, False])
param_names = ["dtype", "expand"]

try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype).str.join("|")
except ImportError:
raise NotImplementedError
def setup(self, dtype, expand):
super().setup(dtype)

def time_extract_single_group(self, dtype, expand):
with warnings.catch_warnings(record=True):
self.s.str.extract("(\\w*)A", expand=expand)


class Dummies(Dtypes):
def setup(self, dtype):
super().setup(dtype)
self.s = self.s.str.join("|")

def time_get_dummies(self, dtype):
self.s.str.get_dummies("|")
Expand All @@ -279,3 +278,9 @@ def setup(self):
def time_vector_slice(self):
# GH 2602
self.s.str[:5]


class Iter(Dtypes):
def time_iter(self, dtype):
for i in self.s:
pass
4 changes: 2 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,7 +3101,7 @@ def _str_extract_noexpand(arr, pat, flags=0):
groups_or_na = _groups_or_na_fun(regex)
result_dtype = _result_dtype(arr)

result = np.array([groups_or_na(val)[0] for val in arr], dtype=object)
result = np.array([groups_or_na(val)[0] for val in np.asarray(arr)], dtype=object)
# not dispatching, so we have to reconstruct here.
result = pd_array(result, dtype=result_dtype)
return result
Expand Down Expand Up @@ -3136,7 +3136,7 @@ def _str_extract_frame(arr, pat, flags=0):
else:
result_index = None
return DataFrame(
[groups_or_na(val) for val in arr],
[groups_or_na(val) for val in np.asarray(arr)],
columns=columns,
index=result_index,
dtype=result_dtype,
Expand Down
24 changes: 11 additions & 13 deletions pandas/tests/strings/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,23 @@ def test_extract_expand_kwarg(any_string_dtype):


def test_extract_expand_False_mixed_object():
mixed = Series(
[
"aBAD_BAD",
np.nan,
"BAD_b_BAD",
True,
datetime.today(),
"foo",
None,
1,
2.0,
]
ser = Series(
["aBAD_BAD", np.nan, "BAD_b_BAD", True, datetime.today(), "foo", None, 1, 2.0]
)

result = Series(mixed).str.extract(".*(BAD[_]+).*(BAD)", expand=False)
# two groups
result = ser.str.extract(".*(BAD[_]+).*(BAD)", expand=False)
er = [np.nan, np.nan] # empty row
expected = DataFrame([["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er])
tm.assert_frame_equal(result, expected)

# single group
result = ser.str.extract(".*(BAD[_]+).*BAD", expand=False)
expected = Series(
["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]
)
tm.assert_series_equal(result, expected)


def test_extract_expand_index_raises():
# GH9980
Expand Down