diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 3e49d6367ffd9..af6ffcb2a9379 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -137,15 +137,11 @@ def test_repeat_mixed_object(): tm.assert_series_equal(result, expected) -def test_repeat_with_null(any_string_dtype): +@pytest.mark.parametrize("arg, repeat", [[None, 4], ["b", None]]) +def test_repeat_with_null(any_string_dtype, arg, repeat): # GH: 31632 - ser = Series(["a", None], dtype=any_string_dtype) - result = ser.str.repeat([3, 4]) - expected = Series(["aaa", np.nan], dtype=any_string_dtype) - tm.assert_series_equal(result, expected) - - ser = Series(["a", "b"], dtype=any_string_dtype) - result = ser.str.repeat([3, None]) + ser = Series(["a", arg], dtype=any_string_dtype) + result = ser.str.repeat([3, repeat]) expected = Series(["aaa", np.nan], dtype=any_string_dtype) tm.assert_series_equal(result, expected) @@ -397,27 +393,28 @@ def test_index_not_found_raises(index_or_series, any_string_dtype): obj.str.index("DE") -def test_index_wrong_type_raises(index_or_series, any_string_dtype): +@pytest.mark.parametrize("method", ["index", "rindex"]) +def test_index_wrong_type_raises(index_or_series, any_string_dtype, method): obj = index_or_series([], dtype=any_string_dtype) msg = "expected a string object, not int" with pytest.raises(TypeError, match=msg): - obj.str.index(0) - - with pytest.raises(TypeError, match=msg): - obj.str.rindex(0) + getattr(obj.str, method)(0) -def test_index_missing(any_string_dtype): +@pytest.mark.parametrize( + "method, exp", + [ + ["index", [1, 1, 0]], + ["rindex", [3, 1, 2]], + ], +) +def test_index_missing(any_string_dtype, method, exp): ser = Series(["abcb", "ab", "bcbe", np.nan], dtype=any_string_dtype) expected_dtype = np.float64 if any_string_dtype == "object" else "Int64" - result = ser.str.index("b") - expected = Series([1, 1, 0, np.nan], dtype=expected_dtype) - tm.assert_series_equal(result, expected) - - result = ser.str.rindex("b") - expected = Series([3, 1, 2, np.nan], dtype=expected_dtype) + result = getattr(ser.str, method)("b") + expected = Series(exp + [np.nan], dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -488,53 +485,51 @@ def test_slice_replace(start, stop, repl, expected, any_string_dtype): tm.assert_series_equal(result, expected) -def test_strip_lstrip_rstrip(any_string_dtype): +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["aa", "bb", np.nan, "cc"]], + ["lstrip", ["aa ", "bb \n", np.nan, "cc "]], + ["rstrip", [" aa", " bb", np.nan, "cc"]], + ], +) +def test_strip_lstrip_rstrip(any_string_dtype, method, exp): ser = Series([" aa ", " bb \n", np.nan, "cc "], dtype=any_string_dtype) - result = ser.str.strip() - expected = Series(["aa", "bb", np.nan, "cc"], dtype=any_string_dtype) - tm.assert_series_equal(result, expected) - - result = ser.str.lstrip() - expected = Series(["aa ", "bb \n", np.nan, "cc "], dtype=any_string_dtype) - tm.assert_series_equal(result, expected) - - result = ser.str.rstrip() - expected = Series([" aa", " bb", np.nan, "cc"], dtype=any_string_dtype) + result = getattr(ser.str, method)() + expected = Series(exp, dtype=any_string_dtype) tm.assert_series_equal(result, expected) -def test_strip_lstrip_rstrip_mixed_object(): +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["aa", np.nan, "bb"]], + ["lstrip", ["aa ", np.nan, "bb \t\n"]], + ["rstrip", [" aa", np.nan, " bb"]], + ], +) +def test_strip_lstrip_rstrip_mixed_object(method, exp): ser = Series([" aa ", np.nan, " bb \t\n", True, datetime.today(), None, 1, 2.0]) - result = ser.str.strip() - expected = Series(["aa", np.nan, "bb", np.nan, np.nan, np.nan, np.nan, np.nan]) - tm.assert_series_equal(result, expected) - - result = ser.str.lstrip() - expected = Series( - ["aa ", np.nan, "bb \t\n", np.nan, np.nan, np.nan, np.nan, np.nan] - ) - tm.assert_series_equal(result, expected) - - result = ser.str.rstrip() - expected = Series([" aa", np.nan, " bb", np.nan, np.nan, np.nan, np.nan, np.nan]) + result = getattr(ser.str, method)() + expected = Series(exp + [np.nan, np.nan, np.nan, np.nan, np.nan]) tm.assert_series_equal(result, expected) -def test_strip_lstrip_rstrip_args(any_string_dtype): +@pytest.mark.parametrize( + "method, exp", + [ + ["strip", ["ABC", " BNSD", "LDFJH "]], + ["lstrip", ["ABCxx", " BNSD", "LDFJH xx"]], + ["rstrip", ["xxABC", "xx BNSD", "LDFJH "]], + ], +) +def test_strip_lstrip_rstrip_args(any_string_dtype, method, exp): ser = Series(["xxABCxx", "xx BNSD", "LDFJH xx"], dtype=any_string_dtype) - result = ser.str.strip("x") - expected = Series(["ABC", " BNSD", "LDFJH "], dtype=any_string_dtype) - tm.assert_series_equal(result, expected) - - result = ser.str.lstrip("x") - expected = Series(["ABCxx", " BNSD", "LDFJH xx"], dtype=any_string_dtype) - tm.assert_series_equal(result, expected) - - result = ser.str.rstrip("x") - expected = Series(["xxABC", "xx BNSD", "LDFJH "], dtype=any_string_dtype) + result = getattr(ser.str, method)("x") + expected = Series(exp, dtype=any_string_dtype) tm.assert_series_equal(result, expected)