Skip to content

Commit 0fba740

Browse files
[ArrowStringArray] implement ArrowStringArray._str_contains (#41025)
1 parent 8de6276 commit 0fba740

File tree

3 files changed

+151
-58
lines changed

3 files changed

+151
-58
lines changed

asv_bench/benchmarks/strings.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,18 @@ def time_cat(self, other_cols, sep, na_rep, na_frac):
213213

214214
class Contains:
215215

216-
params = [True, False]
217-
param_names = ["regex"]
216+
params = (["str", "string", "arrow_string"], [True, False])
217+
param_names = ["dtype", "regex"]
218+
219+
def setup(self, dtype, regex):
220+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
218221

219-
def setup(self, regex):
220-
self.s = Series(tm.makeStringIndex(10 ** 5))
222+
try:
223+
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
224+
except ImportError:
225+
raise NotImplementedError
221226

222-
def time_contains(self, regex):
227+
def time_contains(self, dtype, regex):
223228
self.s.str.contains("A", regex=regex)
224229

225230

pandas/core/arrays/string_arrow.py

+10
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,16 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
759759
# -> We don't know the result type. E.g. `.get` can return anything.
760760
return lib.map_infer_mask(arr, f, mask.view("uint8"))
761761

762+
def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
763+
if not regex and case:
764+
result = pc.match_substring(self._data, pat)
765+
result = BooleanDtype().__from_arrow__(result)
766+
if not isna(na):
767+
result[isna(result)] = bool(na)
768+
return result
769+
else:
770+
return super()._str_contains(pat, case, flags, na, regex)
771+
762772
def _str_isalnum(self):
763773
if hasattr(pc, "utf8_is_alnum"):
764774
result = pc.utf8_is_alnum(self._data)

pandas/tests/strings/test_find_replace.py

+131-53
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import pytest
66

7+
import pandas.util._test_decorators as td
8+
79
import pandas as pd
810
from pandas import (
911
Index,
@@ -12,79 +14,118 @@
1214
)
1315

1416

15-
def test_contains():
17+
@pytest.fixture(
18+
params=[
19+
"object",
20+
"string",
21+
pytest.param(
22+
"arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
23+
),
24+
]
25+
)
26+
def any_string_dtype(request):
27+
"""
28+
Parametrized fixture for string dtypes.
29+
* 'object'
30+
* 'string'
31+
* 'arrow_string'
32+
"""
33+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
34+
35+
return request.param
36+
37+
38+
def test_contains(any_string_dtype):
1639
values = np.array(
1740
["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_
1841
)
19-
values = Series(values)
42+
values = Series(values, dtype=any_string_dtype)
2043
pat = "mmm[_]+"
2144

2245
result = values.str.contains(pat)
23-
expected = Series(np.array([False, np.nan, True, True, False], dtype=np.object_))
46+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
47+
expected = Series(
48+
np.array([False, np.nan, True, True, False], dtype=np.object_),
49+
dtype=expected_dtype,
50+
)
2451
tm.assert_series_equal(result, expected)
2552

2653
result = values.str.contains(pat, regex=False)
27-
expected = Series(np.array([False, np.nan, False, False, True], dtype=np.object_))
54+
expected = Series(
55+
np.array([False, np.nan, False, False, True], dtype=np.object_),
56+
dtype=expected_dtype,
57+
)
2858
tm.assert_series_equal(result, expected)
2959

30-
values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object))
60+
values = Series(
61+
np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object),
62+
dtype=any_string_dtype,
63+
)
3164
result = values.str.contains(pat)
32-
expected = Series(np.array([False, False, True, True]))
33-
assert result.dtype == np.bool_
65+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
66+
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
3467
tm.assert_series_equal(result, expected)
3568

3669
# case insensitive using regex
37-
values = Series(np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object))
70+
values = Series(
71+
np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object),
72+
dtype=any_string_dtype,
73+
)
3874
result = values.str.contains("FOO|mmm", case=False)
39-
expected = Series(np.array([True, False, True, True]))
75+
expected = Series(np.array([True, False, True, True]), dtype=expected_dtype)
4076
tm.assert_series_equal(result, expected)
4177

4278
# case insensitive without regex
43-
result = Series(values).str.contains("foo", regex=False, case=False)
44-
expected = Series(np.array([True, False, True, False]))
79+
result = values.str.contains("foo", regex=False, case=False)
80+
expected = Series(np.array([True, False, True, False]), dtype=expected_dtype)
4581
tm.assert_series_equal(result, expected)
4682

47-
# mixed
83+
# unicode
84+
values = Series(
85+
np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_),
86+
dtype=any_string_dtype,
87+
)
88+
pat = "mmm[_]+"
89+
90+
result = values.str.contains(pat)
91+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
92+
expected = Series(
93+
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
94+
)
95+
tm.assert_series_equal(result, expected)
96+
97+
result = values.str.contains(pat, na=False)
98+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
99+
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
100+
tm.assert_series_equal(result, expected)
101+
102+
values = Series(
103+
np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_),
104+
dtype=any_string_dtype,
105+
)
106+
result = values.str.contains(pat)
107+
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
108+
tm.assert_series_equal(result, expected)
109+
110+
111+
def test_contains_object_mixed():
48112
mixed = Series(
49113
np.array(
50114
["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0],
51115
dtype=object,
52116
)
53117
)
54-
rs = mixed.str.contains("o")
55-
xp = Series(
118+
result = mixed.str.contains("o")
119+
expected = Series(
56120
np.array(
57121
[False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan],
58122
dtype=np.object_,
59123
)
60124
)
61-
tm.assert_series_equal(rs, xp)
62-
63-
rs = mixed.str.contains("o")
64-
xp = Series([False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan])
65-
assert isinstance(rs, Series)
66-
tm.assert_series_equal(rs, xp)
67-
68-
# unicode
69-
values = Series(np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_))
70-
pat = "mmm[_]+"
71-
72-
result = values.str.contains(pat)
73-
expected = Series(np.array([False, np.nan, True, True], dtype=np.object_))
74-
tm.assert_series_equal(result, expected)
75-
76-
result = values.str.contains(pat, na=False)
77-
expected = Series(np.array([False, False, True, True]))
78-
tm.assert_series_equal(result, expected)
79-
80-
values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_))
81-
result = values.str.contains(pat)
82-
expected = Series(np.array([False, False, True, True]))
83-
assert result.dtype == np.bool_
84125
tm.assert_series_equal(result, expected)
85126

86127

87-
def test_contains_for_object_category():
128+
def test_contains_na_kwarg_for_object_category():
88129
# gh 22158
89130

90131
# na for category
@@ -108,6 +149,29 @@ def test_contains_for_object_category():
108149
tm.assert_series_equal(result, expected)
109150

110151

152+
@pytest.mark.parametrize(
153+
"na, expected",
154+
[
155+
(None, pd.NA),
156+
(True, True),
157+
(False, False),
158+
(0, False),
159+
(3, True),
160+
(np.nan, pd.NA),
161+
],
162+
)
163+
@pytest.mark.parametrize("regex", [True, False])
164+
def test_contains_na_kwarg_for_nullable_string_dtype(
165+
nullable_string_dtype, na, expected, regex
166+
):
167+
# https://github.com/pandas-dev/pandas/pull/41025#issuecomment-824062416
168+
169+
values = Series(["a", "b", "c", "a", np.nan], dtype=nullable_string_dtype)
170+
result = values.str.contains("a", na=na, regex=regex)
171+
expected = Series([True, False, False, True, expected], dtype="boolean")
172+
tm.assert_series_equal(result, expected)
173+
174+
111175
@pytest.mark.parametrize("dtype", [None, "category"])
112176
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
113177
@pytest.mark.parametrize("na", [True, False])
@@ -508,59 +572,73 @@ def _check(result, expected):
508572
tm.assert_series_equal(result, expected)
509573

510574

511-
def test_contains_moar():
575+
def test_contains_moar(any_string_dtype):
512576
# PR #1179
513-
s = Series(["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"])
577+
s = Series(
578+
["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"],
579+
dtype=any_string_dtype,
580+
)
514581

515582
result = s.str.contains("a")
583+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
516584
expected = Series(
517-
[False, False, False, True, True, False, np.nan, False, False, True]
585+
[False, False, False, True, True, False, np.nan, False, False, True],
586+
dtype=expected_dtype,
518587
)
519588
tm.assert_series_equal(result, expected)
520589

521590
result = s.str.contains("a", case=False)
522591
expected = Series(
523-
[True, False, False, True, True, False, np.nan, True, False, True]
592+
[True, False, False, True, True, False, np.nan, True, False, True],
593+
dtype=expected_dtype,
524594
)
525595
tm.assert_series_equal(result, expected)
526596

527597
result = s.str.contains("Aa")
528598
expected = Series(
529-
[False, False, False, True, False, False, np.nan, False, False, False]
599+
[False, False, False, True, False, False, np.nan, False, False, False],
600+
dtype=expected_dtype,
530601
)
531602
tm.assert_series_equal(result, expected)
532603

533604
result = s.str.contains("ba")
534605
expected = Series(
535-
[False, False, False, True, False, False, np.nan, False, False, False]
606+
[False, False, False, True, False, False, np.nan, False, False, False],
607+
dtype=expected_dtype,
536608
)
537609
tm.assert_series_equal(result, expected)
538610

539611
result = s.str.contains("ba", case=False)
540612
expected = Series(
541-
[False, False, False, True, True, False, np.nan, True, False, False]
613+
[False, False, False, True, True, False, np.nan, True, False, False],
614+
dtype=expected_dtype,
542615
)
543616
tm.assert_series_equal(result, expected)
544617

545618

546-
def test_contains_nan():
619+
def test_contains_nan(any_string_dtype):
547620
# PR #14171
548-
s = Series([np.nan, np.nan, np.nan], dtype=np.object_)
621+
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)
549622

550623
result = s.str.contains("foo", na=False)
551-
expected = Series([False, False, False], dtype=np.bool_)
624+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
625+
expected = Series([False, False, False], dtype=expected_dtype)
552626
tm.assert_series_equal(result, expected)
553627

554628
result = s.str.contains("foo", na=True)
555-
expected = Series([True, True, True], dtype=np.bool_)
629+
expected = Series([True, True, True], dtype=expected_dtype)
556630
tm.assert_series_equal(result, expected)
557631

558632
result = s.str.contains("foo", na="foo")
559-
expected = Series(["foo", "foo", "foo"], dtype=np.object_)
633+
if any_string_dtype == "object":
634+
expected = Series(["foo", "foo", "foo"], dtype=np.object_)
635+
else:
636+
expected = Series([True, True, True], dtype="boolean")
560637
tm.assert_series_equal(result, expected)
561638

562639
result = s.str.contains("foo")
563-
expected = Series([np.nan, np.nan, np.nan], dtype=np.object_)
640+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
641+
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
564642
tm.assert_series_equal(result, expected)
565643

566644

@@ -609,14 +687,14 @@ def test_replace_moar():
609687
tm.assert_series_equal(result, expected)
610688

611689

612-
def test_match_findall_flags():
690+
def test_flags_kwarg(any_string_dtype):
613691
data = {
614692
"Dave": "[email protected]",
615693
"Steve": "[email protected]",
616694
617695
"Wes": np.nan,
618696
}
619-
data = Series(data)
697+
data = Series(data, dtype=any_string_dtype)
620698

621699
pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})"
622700

0 commit comments

Comments
 (0)