Skip to content

Commit 6d59ae9

Browse files
[ArrowStringArray] TST: parametrize tests/strings/test_find_replace.py (pandas-dev#41471)
1 parent 6641587 commit 6d59ae9

File tree

1 file changed

+101
-85
lines changed

1 file changed

+101
-85
lines changed

pandas/tests/strings/test_find_replace.py

+101 -85
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66

77
import pandas as pd
88
from pandas import (
9-
Index,
109
Series,
1110
_testing as tm,
1211
)
@@ -273,15 +272,14 @@ def test_replace_unicode(any_string_dtype):
273272
tm.assert_series_equal(result, expected)
274273

275274

276-
@pytest.mark.parametrize("klass", [Series, Index])
277275
@pytest.mark.parametrize("repl", [None, 3, {"a": "b"}])
278276
@pytest.mark.parametrize("data", [["a", "b", None], ["a", "b", "c", "ad"]])
279-
def test_replace_raises(any_string_dtype, klass, repl, data):
277+
def test_replace_raises(any_string_dtype, index_or_series, repl, data):
280278
# https://github.com/pandas-dev/pandas/issues/13438
281279
msg = "repl must be a string or callable"
282-
values = klass(data, dtype=any_string_dtype)
280+
obj = index_or_series(data, dtype=any_string_dtype)
283281
with pytest.raises(TypeError, match=msg):
284-
values.str.replace("a", repl)
282+
obj.str.replace("a", repl)
285283

286284

287285
def test_replace_callable(any_string_dtype):
@@ -486,39 +484,32 @@ def test_match_case_kwarg(any_string_dtype):
486484
tm.assert_series_equal(result, expected)
487485

488486

489-
def test_fullmatch():
487+
def test_fullmatch(any_string_dtype):
490488
# GH 32806
491-
ser = Series(["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"])
489+
ser = Series(
490+
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
491+
)
492492
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
493-
expected = Series([True, False, np.nan, False])
493+
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
494+
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
494495
tm.assert_series_equal(result, expected)
495496

496-
ser = Series(["ab", "AB", "abc", "ABC"])
497+
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
497498
result = ser.str.fullmatch("ab", case=False)
498-
expected = Series([True, True, False, False])
499+
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
500+
expected = Series([True, True, False, False], dtype=expected_dtype)
499501
tm.assert_series_equal(result, expected)
500502

501503

502-
def test_fullmatch_nullable_string_dtype(nullable_string_dtype):
503-
ser = Series(
504-
["fooBAD__barBAD", "BAD_BADleroybrown", None, "foo"],
505-
dtype=nullable_string_dtype,
506-
)
507-
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
508-
# Result is nullable boolean
509-
expected = Series([True, False, np.nan, False], dtype="boolean")
504+
def test_findall(any_string_dtype):
505+
ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype)
506+
result = ser.str.findall("BAD[_]*")
507+
expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
510508
tm.assert_series_equal(result, expected)
511509

512510

513-
def test_findall():
514-
values = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"])
515-
516-
result = values.str.findall("BAD[_]*")
517-
exp = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
518-
tm.assert_almost_equal(result, exp)
519-
520-
# mixed
521-
mixed = Series(
511+
def test_findall_mixed_object():
512+
ser = Series(
522513
[
523514
"fooBAD__barBAD",
524515
np.nan,
@@ -532,8 +523,8 @@ def test_findall():
532523
]
533524
)
534525

535-
rs = Series(mixed).str.findall("BAD[_]*")
536-
xp = Series(
526+
result = ser.str.findall("BAD[_]*")
527+
expected = Series(
537528
[
538529
["BAD__", "BAD"],
539530
np.nan,
@@ -547,86 +538,111 @@ def test_findall():
547538
]
548539
)
549540

550-
assert isinstance(rs, Series)
551-
tm.assert_almost_equal(rs, xp)
541+
tm.assert_series_equal(result, expected)
552542

553543

554-
def test_find():
555-
values = Series(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"])
556-
result = values.str.find("EF")
557-
tm.assert_series_equal(result, Series([4, 3, 1, 0, -1]))
558-
expected = np.array([v.find("EF") for v in values.values], dtype=np.int64)
559-
tm.assert_numpy_array_equal(result.values, expected)
544+
def test_find(any_string_dtype):
545+
ser = Series(
546+
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype
547+
)
548+
expected_dtype = np.int64 if any_string_dtype == "object" else "Int64"
560549

561-
result = values.str.rfind("EF")
562-
tm.assert_series_equal(result, Series([4, 5, 7, 4, -1]))
563-
expected = np.array([v.rfind("EF") for v in values.values], dtype=np.int64)
564-
tm.assert_numpy_array_equal(result.values, expected)
550+
result = ser.str.find("EF")
551+
expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype)
552+
tm.assert_series_equal(result, expected)
553+
expected = np.array([v.find("EF") for v in np.array(ser)], dtype=np.int64)
554+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
565555

566-
result = values.str.find("EF", 3)
567-
tm.assert_series_equal(result, Series([4, 3, 7, 4, -1]))
568-
expected = np.array([v.find("EF", 3) for v in values.values], dtype=np.int64)
569-
tm.assert_numpy_array_equal(result.values, expected)
556+
result = ser.str.rfind("EF")
557+
expected = Series([4, 5, 7, 4, -1], dtype=expected_dtype)
558+
tm.assert_series_equal(result, expected)
559+
expected = np.array([v.rfind("EF") for v in np.array(ser)], dtype=np.int64)
560+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
561+
562+
result = ser.str.find("EF", 3)
563+
expected = Series([4, 3, 7, 4, -1], dtype=expected_dtype)
564+
tm.assert_series_equal(result, expected)
565+
expected = np.array([v.find("EF", 3) for v in np.array(ser)], dtype=np.int64)
566+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
567+
568+
result = ser.str.rfind("EF", 3)
569+
expected = Series([4, 5, 7, 4, -1], dtype=expected_dtype)
570+
tm.assert_series_equal(result, expected)
571+
expected = np.array([v.rfind("EF", 3) for v in np.array(ser)], dtype=np.int64)
572+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
570573

571-
result = values.str.rfind("EF", 3)
572-
tm.assert_series_equal(result, Series([4, 5, 7, 4, -1]))
573-
expected = np.array([v.rfind("EF", 3) for v in values.values], dtype=np.int64)
574-
tm.assert_numpy_array_equal(result.values, expected)
574+
result = ser.str.find("EF", 3, 6)
575+
expected = Series([4, 3, -1, 4, -1], dtype=expected_dtype)
576+
tm.assert_series_equal(result, expected)
577+
expected = np.array([v.find("EF", 3, 6) for v in np.array(ser)], dtype=np.int64)
578+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
575579

576-
result = values.str.find("EF", 3, 6)
577-
tm.assert_series_equal(result, Series([4, 3, -1, 4, -1]))
578-
expected = np.array([v.find("EF", 3, 6) for v in values.values], dtype=np.int64)
579-
tm.assert_numpy_array_equal(result.values, expected)
580+
result = ser.str.rfind("EF", 3, 6)
581+
expected = Series([4, 3, -1, 4, -1], dtype=expected_dtype)
582+
tm.assert_series_equal(result, expected)
583+
expected = np.array([v.rfind("EF", 3, 6) for v in np.array(ser)], dtype=np.int64)
584+
tm.assert_numpy_array_equal(np.array(result, dtype=np.int64), expected)
580585

581-
result = values.str.rfind("EF", 3, 6)
582-
tm.assert_series_equal(result, Series([4, 3, -1, 4, -1]))
583-
expected = np.array([v.rfind("EF", 3, 6) for v in values.values], dtype=np.int64)
584-
tm.assert_numpy_array_equal(result.values, expected)
585586

587+
def test_find_bad_arg_raises(any_string_dtype):
588+
ser = Series([], dtype=any_string_dtype)
586589
with pytest.raises(TypeError, match="expected a string object, not int"):
587-
result = values.str.find(0)
590+
ser.str.find(0)
588591

589592
with pytest.raises(TypeError, match="expected a string object, not int"):
590-
result = values.str.rfind(0)
593+
ser.str.rfind(0)
591594

592595

593-
def test_find_nan():
594-
values = Series(["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"])
595-
result = values.str.find("EF")
596-
tm.assert_series_equal(result, Series([4, np.nan, 1, np.nan, -1]))
596+
def test_find_nan(any_string_dtype):
597+
ser = Series(
598+
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
599+
)
600+
expected_dtype = np.float64 if any_string_dtype == "object" else "Int64"
597601

598-
result = values.str.rfind("EF")
599-
tm.assert_series_equal(result, Series([4, np.nan, 7, np.nan, -1]))
602+
result = ser.str.find("EF")
603+
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)
604+
tm.assert_series_equal(result, expected)
600605

601-
result = values.str.find("EF", 3)
602-
tm.assert_series_equal(result, Series([4, np.nan, 7, np.nan, -1]))
606+
result = ser.str.rfind("EF")
607+
expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype)
608+
tm.assert_series_equal(result, expected)
603609

604-
result = values.str.rfind("EF", 3)
605-
tm.assert_series_equal(result, Series([4, np.nan, 7, np.nan, -1]))
610+
result = ser.str.find("EF", 3)
611+
expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype)
612+
tm.assert_series_equal(result, expected)
606613

607-
result = values.str.find("EF", 3, 6)
608-
tm.assert_series_equal(result, Series([4, np.nan, -1, np.nan, -1]))
614+
result = ser.str.rfind("EF", 3)
615+
expected = Series([4, np.nan, 7, np.nan, -1], dtype=expected_dtype)
616+
tm.assert_series_equal(result, expected)
609617

610-
result = values.str.rfind("EF", 3, 6)
611-
tm.assert_series_equal(result, Series([4, np.nan, -1, np.nan, -1]))
618+
result = ser.str.find("EF", 3, 6)
619+
expected = Series([4, np.nan, -1, np.nan, -1], dtype=expected_dtype)
620+
tm.assert_series_equal(result, expected)
612621

622+
result = ser.str.rfind("EF", 3, 6)
623+
expected = Series([4, np.nan, -1, np.nan, -1], dtype=expected_dtype)
624+
tm.assert_series_equal(result, expected)
613625

614-
def test_translate():
615-
def _check(result, expected):
616-
if isinstance(result, Series):
617-
tm.assert_series_equal(result, expected)
618-
else:
619-
tm.assert_index_equal(result, expected)
620626

621-
for klass in [Series, Index]:
622-
s = klass(["abcdefg", "abcc", "cdddfg", "cdefggg"])
623-
table = str.maketrans("abc", "cde")
624-
result = s.str.translate(table)
625-
expected = klass(["cdedefg", "cdee", "edddfg", "edefggg"])
626-
_check(result, expected)
627+
def test_translate(index_or_series, any_string_dtype):
628+
obj = index_or_series(
629+
["abcdefg", "abcc", "cdddfg", "cdefggg"], dtype=any_string_dtype
630+
)
631+
table = str.maketrans("abc", "cde")
632+
result = obj.str.translate(table)
633+
expected = index_or_series(
634+
["cdedefg", "cdee", "edddfg", "edefggg"], dtype=any_string_dtype
635+
)
636+
if index_or_series is Series:
637+
tm.assert_series_equal(result, expected)
638+
else:
639+
tm.assert_index_equal(result, expected)
640+
627641

642+
def test_translate_mixed_object():
628643
# Series with non-string values
629644
s = Series(["a", "b", "c", 1.2])
645+
table = str.maketrans("abc", "cde")
630646
expected = Series(["c", "d", "e", np.nan])
631647
result = s.str.translate(table)
632648
tm.assert_series_equal(result, expected)

0 commit comments

Comments (0)