18
18
# --------------------------------------------------------------------------------------
19
19
20
20
21
+ def using_pyarrow (dtype ):
22
+ return dtype in ("string[pyarrow]" ,)
23
+
24
+
21
25
def test_contains (any_string_dtype ):
22
26
values = np .array (
23
27
["foo" , np .nan , "fooommm__foo" , "mmm_" , "foommm[_]+bar" ], dtype = np .object_
@@ -379,9 +383,7 @@ def test_replace_mixed_object():
379
383
def test_replace_unicode (any_string_dtype ):
380
384
ser = Series ([b"abcd,\xc3 \xa0 " .decode ("utf-8" )], dtype = any_string_dtype )
381
385
expected = Series ([b"abcd, \xc3 \xa0 " .decode ("utf-8" )], dtype = any_string_dtype )
382
- with tm .maybe_produces_warning (
383
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
384
- ):
386
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
385
387
result = ser .str .replace (r"(?<=\w),(?=\w)" , ", " , flags = re .UNICODE , regex = True )
386
388
tm .assert_series_equal (result , expected )
387
389
@@ -402,9 +404,7 @@ def test_replace_callable(any_string_dtype):
402
404
403
405
# test with callable
404
406
repl = lambda m : m .group (0 ).swapcase ()
405
- with tm .maybe_produces_warning (
406
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
407
- ):
407
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
408
408
result = ser .str .replace ("[a-z][A-Z]{2}" , repl , n = 2 , regex = True )
409
409
expected = Series (["foObaD__baRbaD" , np .nan ], dtype = any_string_dtype )
410
410
tm .assert_series_equal (result , expected )
@@ -424,7 +424,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
424
424
)
425
425
with pytest .raises (TypeError , match = msg ):
426
426
with tm .maybe_produces_warning (
427
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
427
+ PerformanceWarning , using_pyarrow ( any_string_dtype )
428
428
):
429
429
values .str .replace ("a" , repl , regex = True )
430
430
@@ -434,9 +434,7 @@ def test_replace_callable_named_groups(any_string_dtype):
434
434
ser = Series (["Foo Bar Baz" , np .nan ], dtype = any_string_dtype )
435
435
pat = r"(?P<first>\w+) (?P<middle>\w+) (?P<last>\w+)"
436
436
repl = lambda m : m .group ("middle" ).swapcase ()
437
- with tm .maybe_produces_warning (
438
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
439
- ):
437
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
440
438
result = ser .str .replace (pat , repl , regex = True )
441
439
expected = Series (["bAR" , np .nan ], dtype = any_string_dtype )
442
440
tm .assert_series_equal (result , expected )
@@ -448,16 +446,12 @@ def test_replace_compiled_regex(any_string_dtype):
448
446
449
447
# test with compiled regex
450
448
pat = re .compile (r"BAD_*" )
451
- with tm .maybe_produces_warning (
452
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
453
- ):
449
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
454
450
result = ser .str .replace (pat , "" , regex = True )
455
451
expected = Series (["foobar" , np .nan ], dtype = any_string_dtype )
456
452
tm .assert_series_equal (result , expected )
457
453
458
- with tm .maybe_produces_warning (
459
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
460
- ):
454
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
461
455
result = ser .str .replace (pat , "" , n = 1 , regex = True )
462
456
expected = Series (["foobarBAD" , np .nan ], dtype = any_string_dtype )
463
457
tm .assert_series_equal (result , expected )
@@ -477,9 +471,7 @@ def test_replace_compiled_regex_unicode(any_string_dtype):
477
471
ser = Series ([b"abcd,\xc3 \xa0 " .decode ("utf-8" )], dtype = any_string_dtype )
478
472
expected = Series ([b"abcd, \xc3 \xa0 " .decode ("utf-8" )], dtype = any_string_dtype )
479
473
pat = re .compile (r"(?<=\w),(?=\w)" , flags = re .UNICODE )
480
- with tm .maybe_produces_warning (
481
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
482
- ):
474
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
483
475
result = ser .str .replace (pat , ", " , regex = True )
484
476
tm .assert_series_equal (result , expected )
485
477
@@ -507,9 +499,7 @@ def test_replace_compiled_regex_callable(any_string_dtype):
507
499
ser = Series (["fooBAD__barBAD" , np .nan ], dtype = any_string_dtype )
508
500
repl = lambda m : m .group (0 ).swapcase ()
509
501
pat = re .compile ("[a-z][A-Z]{2}" )
510
- with tm .maybe_produces_warning (
511
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
512
- ):
502
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
513
503
result = ser .str .replace (pat , repl , n = 2 , regex = True )
514
504
expected = Series (["foObaD__baRbaD" , np .nan ], dtype = any_string_dtype )
515
505
tm .assert_series_equal (result , expected )
@@ -558,9 +548,7 @@ def test_replace_moar(any_string_dtype):
558
548
)
559
549
tm .assert_series_equal (result , expected )
560
550
561
- with tm .maybe_produces_warning (
562
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
563
- ):
551
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
564
552
result = ser .str .replace ("A" , "YYY" , case = False )
565
553
expected = Series (
566
554
[
@@ -579,9 +567,7 @@ def test_replace_moar(any_string_dtype):
579
567
)
580
568
tm .assert_series_equal (result , expected )
581
569
582
- with tm .maybe_produces_warning (
583
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
584
- ):
570
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
585
571
result = ser .str .replace ("^.a|dog" , "XX-XX " , case = False , regex = True )
586
572
expected = Series (
587
573
[
@@ -605,16 +591,12 @@ def test_replace_not_case_sensitive_not_regex(any_string_dtype):
605
591
# https://github.com/pandas-dev/pandas/issues/41602
606
592
ser = Series (["A." , "a." , "Ab" , "ab" , np .nan ], dtype = any_string_dtype )
607
593
608
- with tm .maybe_produces_warning (
609
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
610
- ):
594
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
611
595
result = ser .str .replace ("a" , "c" , case = False , regex = False )
612
596
expected = Series (["c." , "c." , "cb" , "cb" , np .nan ], dtype = any_string_dtype )
613
597
tm .assert_series_equal (result , expected )
614
598
615
- with tm .maybe_produces_warning (
616
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
617
- ):
599
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
618
600
result = ser .str .replace ("a." , "c." , case = False , regex = False )
619
601
expected = Series (["c." , "c." , "Ab" , "ab" , np .nan ], dtype = any_string_dtype )
620
602
tm .assert_series_equal (result , expected )
@@ -762,9 +744,7 @@ def test_fullmatch_case_kwarg(any_string_dtype):
762
744
result = ser .str .fullmatch ("ab" , case = False )
763
745
tm .assert_series_equal (result , expected )
764
746
765
- with tm .maybe_produces_warning (
766
- PerformanceWarning , any_string_dtype == "string[pyarrow]"
767
- ):
747
+ with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow (any_string_dtype )):
768
748
result = ser .str .fullmatch ("ab" , flags = re .IGNORECASE )
769
749
tm .assert_series_equal (result , expected )
770
750
@@ -945,16 +925,16 @@ def test_flags_kwarg(any_string_dtype):
945
925
946
926
pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})"
947
927
948
- using_pyarrow = any_string_dtype == "string[pyarrow]"
928
+ use_pyarrow = using_pyarrow ( any_string_dtype )
949
929
950
930
result = data .str .extract (pat , flags = re .IGNORECASE , expand = True )
951
931
assert result .iloc [0 ].tolist () == ["dave" , "google" , "com" ]
952
932
953
- with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow ):
933
+ with tm .maybe_produces_warning (PerformanceWarning , use_pyarrow ):
954
934
result = data .str .match (pat , flags = re .IGNORECASE )
955
935
assert result .iloc [0 ]
956
936
957
- with tm .maybe_produces_warning (PerformanceWarning , using_pyarrow ):
937
+ with tm .maybe_produces_warning (PerformanceWarning , use_pyarrow ):
958
938
result = data .str .fullmatch (pat , flags = re .IGNORECASE )
959
939
assert result .iloc [0 ]
960
940
@@ -966,7 +946,7 @@ def test_flags_kwarg(any_string_dtype):
966
946
967
947
msg = "has match groups"
968
948
with tm .assert_produces_warning (
969
- UserWarning , match = msg , raise_on_extra_warnings = not using_pyarrow
949
+ UserWarning , match = msg , raise_on_extra_warnings = not use_pyarrow
970
950
):
971
951
result = data .str .contains (pat , flags = re .IGNORECASE )
972
952
assert result .iloc [0 ]
0 commit comments