diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index c3cc8b3643ed2..5f85a33dd2753 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -17,6 +17,10 @@ # -------------------------------------------------------------------------------------- +def using_pyarrow(dtype): + return dtype in ("string[pyarrow]",) + + def test_contains(any_string_dtype): values = np.array( ["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_ @@ -378,9 +382,7 @@ def test_replace_mixed_object(): def test_replace_unicode(any_string_dtype): ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype) expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(r"(?<=\w),(?=\w)", ", ", flags=re.UNICODE, regex=True) tm.assert_series_equal(result, expected) @@ -401,9 +403,7 @@ def test_replace_callable(any_string_dtype): # test with callable repl = lambda m: m.group(0).swapcase() - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace("[a-z][A-Z]{2}", repl, n=2, regex=True) expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -423,7 +423,7 @@ def test_replace_callable_raises(any_string_dtype, repl): ) with pytest.raises(TypeError, match=msg): with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" + PerformanceWarning, using_pyarrow(any_string_dtype) ): values.str.replace("a", repl, regex=True) @@ -433,9 +433,7 @@ def test_replace_callable_named_groups(any_string_dtype): ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype) pat = r"(?P\w+) (?P\w+) (?P\w+)" repl = lambda m: m.group("middle").swapcase() - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(pat, repl, regex=True) expected = Series(["bAR", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -447,16 +445,12 @@ def test_replace_compiled_regex(any_string_dtype): # test with compiled regex pat = re.compile(r"BAD_*") - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(pat, "", regex=True) expected = Series(["foobar", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(pat, "", n=1, regex=True) expected = Series(["foobarBAD", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -476,9 +470,7 @@ def test_replace_compiled_regex_unicode(any_string_dtype): ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype) expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype) pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(pat, ", ", regex=True) tm.assert_series_equal(result, expected) @@ -506,9 +498,7 @@ def test_replace_compiled_regex_callable(any_string_dtype): ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) repl = lambda m: m.group(0).swapcase() pat = re.compile("[a-z][A-Z]{2}") - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace(pat, repl, n=2, regex=True) expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -557,9 +547,7 @@ def test_replace_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace("A", "YYY", case=False) expected = Series( [ @@ -578,9 +566,7 @@ def test_replace_moar(any_string_dtype): ) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace("^.a|dog", "XX-XX ", case=False, regex=True) expected = Series( [ @@ -604,16 +590,12 @@ def test_replace_not_case_sensitive_not_regex(any_string_dtype): # https://github.com/pandas-dev/pandas/issues/41602 ser = Series(["A.", "a.", "Ab", "ab", np.nan], dtype=any_string_dtype) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace("a", "c", case=False, regex=False) expected = Series(["c.", "c.", "cb", "cb", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.replace("a.", "c.", case=False, regex=False) expected = Series(["c.", "c.", "Ab", "ab", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -761,9 +743,7 @@ def test_fullmatch_case_kwarg(any_string_dtype): result = ser.str.fullmatch("ab", case=False) tm.assert_series_equal(result, expected) - with tm.maybe_produces_warning( - PerformanceWarning, any_string_dtype == "string[pyarrow]" - ): + with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)): result = ser.str.fullmatch("ab", flags=re.IGNORECASE) tm.assert_series_equal(result, expected) @@ -944,16 +924,16 @@ def test_flags_kwarg(any_string_dtype): pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})" - using_pyarrow = any_string_dtype == "string[pyarrow]" + use_pyarrow = using_pyarrow(any_string_dtype) result = data.str.extract(pat, flags=re.IGNORECASE, expand=True) assert result.iloc[0].tolist() == ["dave", "google", "com"] - with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow): + with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow): result = data.str.match(pat, flags=re.IGNORECASE) assert result.iloc[0] - with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow): + with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow): result = data.str.fullmatch(pat, flags=re.IGNORECASE) assert result.iloc[0] @@ -965,7 +945,7 @@ def test_flags_kwarg(any_string_dtype): msg = "has match groups" with tm.assert_produces_warning( - UserWarning, match=msg, raise_on_extra_warnings=not using_pyarrow + UserWarning, match=msg, raise_on_extra_warnings=not use_pyarrow ): result = data.str.contains(pat, flags=re.IGNORECASE) assert result.iloc[0]