Skip to content

Commit cad55a1

Browse files
phoflmeeseeksmachine
authored andcommitted
Backport PR pandas-dev#54537: REF: Refactor using_pyarrow check for string tests
1 parent 64e4527 commit cad55a1

File tree

1 file changed

+21
-41
lines changed

1 file changed

+21
-41
lines changed

pandas/tests/strings/test_find_replace.py

+21-41
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
# --------------------------------------------------------------------------------------
1919

2020

21+
def using_pyarrow(dtype):
22+
return dtype in ("string[pyarrow]",)
23+
24+
2125
def test_contains(any_string_dtype):
2226
values = np.array(
2327
["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_
@@ -379,9 +383,7 @@ def test_replace_mixed_object():
379383
def test_replace_unicode(any_string_dtype):
380384
ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype)
381385
expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype)
382-
with tm.maybe_produces_warning(
383-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
384-
):
386+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
385387
result = ser.str.replace(r"(?<=\w),(?=\w)", ", ", flags=re.UNICODE, regex=True)
386388
tm.assert_series_equal(result, expected)
387389

@@ -402,9 +404,7 @@ def test_replace_callable(any_string_dtype):
402404

403405
# test with callable
404406
repl = lambda m: m.group(0).swapcase()
405-
with tm.maybe_produces_warning(
406-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
407-
):
407+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
408408
result = ser.str.replace("[a-z][A-Z]{2}", repl, n=2, regex=True)
409409
expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype)
410410
tm.assert_series_equal(result, expected)
@@ -424,7 +424,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
424424
)
425425
with pytest.raises(TypeError, match=msg):
426426
with tm.maybe_produces_warning(
427-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
427+
PerformanceWarning, using_pyarrow(any_string_dtype)
428428
):
429429
values.str.replace("a", repl, regex=True)
430430

@@ -434,9 +434,7 @@ def test_replace_callable_named_groups(any_string_dtype):
434434
ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)
435435
pat = r"(?P<first>\w+) (?P<middle>\w+) (?P<last>\w+)"
436436
repl = lambda m: m.group("middle").swapcase()
437-
with tm.maybe_produces_warning(
438-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
439-
):
437+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
440438
result = ser.str.replace(pat, repl, regex=True)
441439
expected = Series(["bAR", np.nan], dtype=any_string_dtype)
442440
tm.assert_series_equal(result, expected)
@@ -448,16 +446,12 @@ def test_replace_compiled_regex(any_string_dtype):
448446

449447
# test with compiled regex
450448
pat = re.compile(r"BAD_*")
451-
with tm.maybe_produces_warning(
452-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
453-
):
449+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
454450
result = ser.str.replace(pat, "", regex=True)
455451
expected = Series(["foobar", np.nan], dtype=any_string_dtype)
456452
tm.assert_series_equal(result, expected)
457453

458-
with tm.maybe_produces_warning(
459-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
460-
):
454+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
461455
result = ser.str.replace(pat, "", n=1, regex=True)
462456
expected = Series(["foobarBAD", np.nan], dtype=any_string_dtype)
463457
tm.assert_series_equal(result, expected)
@@ -477,9 +471,7 @@ def test_replace_compiled_regex_unicode(any_string_dtype):
477471
ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype)
478472
expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype)
479473
pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE)
480-
with tm.maybe_produces_warning(
481-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
482-
):
474+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
483475
result = ser.str.replace(pat, ", ", regex=True)
484476
tm.assert_series_equal(result, expected)
485477

@@ -507,9 +499,7 @@ def test_replace_compiled_regex_callable(any_string_dtype):
507499
ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype)
508500
repl = lambda m: m.group(0).swapcase()
509501
pat = re.compile("[a-z][A-Z]{2}")
510-
with tm.maybe_produces_warning(
511-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
512-
):
502+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
513503
result = ser.str.replace(pat, repl, n=2, regex=True)
514504
expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype)
515505
tm.assert_series_equal(result, expected)
@@ -558,9 +548,7 @@ def test_replace_moar(any_string_dtype):
558548
)
559549
tm.assert_series_equal(result, expected)
560550

561-
with tm.maybe_produces_warning(
562-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
563-
):
551+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
564552
result = ser.str.replace("A", "YYY", case=False)
565553
expected = Series(
566554
[
@@ -579,9 +567,7 @@ def test_replace_moar(any_string_dtype):
579567
)
580568
tm.assert_series_equal(result, expected)
581569

582-
with tm.maybe_produces_warning(
583-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
584-
):
570+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
585571
result = ser.str.replace("^.a|dog", "XX-XX ", case=False, regex=True)
586572
expected = Series(
587573
[
@@ -605,16 +591,12 @@ def test_replace_not_case_sensitive_not_regex(any_string_dtype):
605591
# https://github.com/pandas-dev/pandas/issues/41602
606592
ser = Series(["A.", "a.", "Ab", "ab", np.nan], dtype=any_string_dtype)
607593

608-
with tm.maybe_produces_warning(
609-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
610-
):
594+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
611595
result = ser.str.replace("a", "c", case=False, regex=False)
612596
expected = Series(["c.", "c.", "cb", "cb", np.nan], dtype=any_string_dtype)
613597
tm.assert_series_equal(result, expected)
614598

615-
with tm.maybe_produces_warning(
616-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
617-
):
599+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
618600
result = ser.str.replace("a.", "c.", case=False, regex=False)
619601
expected = Series(["c.", "c.", "Ab", "ab", np.nan], dtype=any_string_dtype)
620602
tm.assert_series_equal(result, expected)
@@ -762,9 +744,7 @@ def test_fullmatch_case_kwarg(any_string_dtype):
762744
result = ser.str.fullmatch("ab", case=False)
763745
tm.assert_series_equal(result, expected)
764746

765-
with tm.maybe_produces_warning(
766-
PerformanceWarning, any_string_dtype == "string[pyarrow]"
767-
):
747+
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow(any_string_dtype)):
768748
result = ser.str.fullmatch("ab", flags=re.IGNORECASE)
769749
tm.assert_series_equal(result, expected)
770750

@@ -945,16 +925,16 @@ def test_flags_kwarg(any_string_dtype):
945925

946926
pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})"
947927

948-
using_pyarrow = any_string_dtype == "string[pyarrow]"
928+
use_pyarrow = using_pyarrow(any_string_dtype)
949929

950930
result = data.str.extract(pat, flags=re.IGNORECASE, expand=True)
951931
assert result.iloc[0].tolist() == ["dave", "google", "com"]
952932

953-
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow):
933+
with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow):
954934
result = data.str.match(pat, flags=re.IGNORECASE)
955935
assert result.iloc[0]
956936

957-
with tm.maybe_produces_warning(PerformanceWarning, using_pyarrow):
937+
with tm.maybe_produces_warning(PerformanceWarning, use_pyarrow):
958938
result = data.str.fullmatch(pat, flags=re.IGNORECASE)
959939
assert result.iloc[0]
960940

@@ -966,7 +946,7 @@ def test_flags_kwarg(any_string_dtype):
966946

967947
msg = "has match groups"
968948
with tm.assert_produces_warning(
969-
UserWarning, match=msg, raise_on_extra_warnings=not using_pyarrow
949+
UserWarning, match=msg, raise_on_extra_warnings=not use_pyarrow
970950
):
971951
result = data.str.contains(pat, flags=re.IGNORECASE)
972952
assert result.iloc[0]

0 commit comments

Comments
 (0)