Skip to content

Commit 9806b75

Browse files
[ArrowStringArray] PERF: use pa.compute.match_substring_regex for str.match if available (#41326)
1 parent 0677491 commit 9806b75

File tree

4 files changed

+63
-31
lines changed

4 files changed

+63
-31
lines changed

pandas/core/arrays/string_arrow.py

+8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Dtype,
2020
NpDtype,
2121
PositionalIndexer,
22+
Scalar,
2223
type_t,
2324
)
2425
from pandas.util._decorators import doc
@@ -810,6 +811,13 @@ def _str_endswith(self, pat, na=None):
810811
else:
811812
return super()._str_endswith(pat, na)
812813

814+
def _str_match(
    self, pat: str, case: bool = True, flags: int = 0, na: Scalar = None
):
    """
    Test whether each string matches the regex ``pat`` at its start.

    The pattern is anchored to the beginning of the string and the work
    is delegated to ``self._str_contains`` with ``regex=True``.

    Parameters
    ----------
    pat : str
        Regular expression to match at the start of each string.
    case : bool, default True
        Passed through unchanged to ``_str_contains``.
    flags : int, default 0
        Passed through unchanged to ``_str_contains``.
    na : scalar, optional
        Passed through unchanged to ``_str_contains``.

    Returns
    -------
    Whatever ``self._str_contains`` returns for the anchored pattern.
    """
    # Group the pattern before anchoring. A bare "^" + pat binds only to
    # the first alternative, so "a|b" would become "^a|b" and match a "b"
    # anywhere in the string. Grouping unconditionally also covers patterns
    # that already begin with "^" but contain a top-level alternation
    # (e.g. "^a|b"); a redundant inner "^" is harmless.
    pat = f"^(?:{pat})"
    return self._str_contains(pat, case, flags, na, regex=True)
820+
813821
def _str_isalnum(self):
814822
result = pc.utf8_is_alnum(self._data)
815823
return BooleanDtype().__from_arrow__(result)

pandas/core/strings/base.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,7 @@ def _str_repeat(self, repeats):
6161

6262
@abc.abstractmethod
6363
def _str_match(
64-
self,
65-
pat: Union[str, Pattern],
66-
case: bool = True,
67-
flags: int = 0,
68-
na: Scalar = np.nan,
64+
self, pat: str, case: bool = True, flags: int = 0, na: Scalar = np.nan
6965
):
7066
pass
7167

pandas/core/strings/object_array.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,7 @@ def rep(x, r):
186186
return result
187187

188188
def _str_match(
189-
self,
190-
pat: Union[str, Pattern],
191-
case: bool = True,
192-
flags: int = 0,
193-
na: Scalar = None,
189+
self, pat: str, case: bool = True, flags: int = 0, na: Scalar = None
194190
):
195191
if not case:
196192
flags |= re.IGNORECASE

pandas/tests/strings/test_find_replace.py

+53-21
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,39 @@ def test_replace_literal(any_string_dtype):
409409
values.str.replace(compiled_pat, "", regex=False)
410410

411411

412-
def test_match():
412+
def test_match(any_string_dtype):
413413
# New match behavior introduced in 0.13
414-
values = Series(["fooBAD__barBAD", np.nan, "foo"])
414+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
415+
416+
values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
415417
result = values.str.match(".*(BAD[_]+).*(BAD)")
416-
exp = Series([True, np.nan, False])
417-
tm.assert_series_equal(result, exp)
418+
expected = Series([True, np.nan, False], dtype=expected_dtype)
419+
tm.assert_series_equal(result, expected)
418420

419-
values = Series(["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"])
421+
values = Series(
422+
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
423+
)
420424
result = values.str.match(".*BAD[_]+.*BAD")
421-
exp = Series([True, True, np.nan, False])
422-
tm.assert_series_equal(result, exp)
425+
expected = Series([True, True, np.nan, False], dtype=expected_dtype)
426+
tm.assert_series_equal(result, expected)
423427

424-
# mixed
428+
result = values.str.match("BAD[_]+.*BAD")
429+
expected = Series([False, True, np.nan, False], dtype=expected_dtype)
430+
tm.assert_series_equal(result, expected)
431+
432+
values = Series(
433+
["fooBAD__barBAD", "^BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
434+
)
435+
result = values.str.match("^BAD[_]+.*BAD")
436+
expected = Series([False, False, np.nan, False], dtype=expected_dtype)
437+
tm.assert_series_equal(result, expected)
438+
439+
result = values.str.match("\\^BAD[_]+.*BAD")
440+
expected = Series([False, True, np.nan, False], dtype=expected_dtype)
441+
tm.assert_series_equal(result, expected)
442+
443+
444+
def test_match_mixed_object():
425445
mixed = Series(
426446
[
427447
"aBAD_BAD",
@@ -435,22 +455,34 @@ def test_match():
435455
2.0,
436456
]
437457
)
438-
rs = Series(mixed).str.match(".*(BAD[_]+).*(BAD)")
439-
xp = Series([True, np.nan, True, np.nan, np.nan, False, np.nan, np.nan, np.nan])
440-
assert isinstance(rs, Series)
441-
tm.assert_series_equal(rs, xp)
458+
result = Series(mixed).str.match(".*(BAD[_]+).*(BAD)")
459+
expected = Series(
460+
[True, np.nan, True, np.nan, np.nan, False, np.nan, np.nan, np.nan]
461+
)
462+
assert isinstance(result, Series)
463+
tm.assert_series_equal(result, expected)
464+
442465

443-
# na GH #6609
444-
res = Series(["a", 0, np.nan]).str.match("a", na=False)
445-
exp = Series([True, False, False])
446-
tm.assert_series_equal(exp, res)
447-
res = Series(["a", 0, np.nan]).str.match("a")
448-
exp = Series([True, np.nan, np.nan])
449-
tm.assert_series_equal(exp, res)
466+
def test_match_na_kwarg(any_string_dtype):
467+
# GH #6609
468+
s = Series(["a", "b", np.nan], dtype=any_string_dtype)
450469

451-
values = Series(["ab", "AB", "abc", "ABC"])
470+
result = s.str.match("a", na=False)
471+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
472+
expected = Series([True, False, False], dtype=expected_dtype)
473+
tm.assert_series_equal(result, expected)
474+
475+
result = s.str.match("a")
476+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
477+
expected = Series([True, False, np.nan], dtype=expected_dtype)
478+
tm.assert_series_equal(result, expected)
479+
480+
481+
def test_match_case_kwarg(any_string_dtype):
482+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
452483
result = values.str.match("ab", case=False)
453-
expected = Series([True, True, True, True])
484+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
485+
expected = Series([True, True, True, True], dtype=expected_dtype)
454486
tm.assert_series_equal(result, expected)
455487

456488

0 commit comments

Comments
 (0)