diff --git a/doc/source/whatsnew/v1.1.3.rst b/doc/source/whatsnew/v1.1.3.rst index e3161012da5d1..c06990e3f2051 100644 --- a/doc/source/whatsnew/v1.1.3.rst +++ b/doc/source/whatsnew/v1.1.3.rst @@ -22,7 +22,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ -- +- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with ``category`` dtype not propagating ``na`` parameter (:issue:`36241`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/strings.py b/pandas/core/strings.py index 6702bf519c52e..4decd86764ccc 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -2050,7 +2050,7 @@ def wrapper2(self, pat, flags=0, **kwargs): @forbid_nonstring_types(forbidden_types, name=name) def wrapper3(self, pat, na=np.nan): result = f(self._parent, pat, na=na) - return self._wrap_result(result, returns_string=returns_string) + return self._wrap_result(result, returns_string=returns_string, fill_value=na) wrapper = wrapper3 if na else wrapper2 if flags else wrapper1 diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index d9396d70f9112..c792a48d3ef08 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -29,6 +29,8 @@ def assert_series_or_index_equal(left, right): ("decode", ("UTF-8",), {}), ("encode", ("UTF-8",), {}), ("endswith", ("a",), {}), + ("endswith", ("a",), {"na": True}), + ("endswith", ("a",), {"na": False}), ("extract", ("([a-z]*)",), {"expand": False}), ("extract", ("([a-z]*)",), {"expand": True}), ("extractall", ("([a-z]*)",), {}), @@ -58,6 +60,8 @@ def assert_series_or_index_equal(left, right): ("split", (" ",), {"expand": False}), ("split", (" ",), {"expand": True}), ("startswith", ("a",), {}), + ("startswith", ("a",), {"na": True}), + ("startswith", ("a",), {"na": False}), # translating unicode points of "a" to "d" ("translate", ({97: 100},), {}), ("wrap", (2,), {}), @@ -838,15 +842,23 @@ def test_contains_for_object_category(self): expected = Series([True, False, False, True, False]) tm.assert_series_equal(result, expected) - def test_startswith(self): - values = Series(["om", np.nan, "foo_nom", "nom", "bar_foo", np.nan, "foo"]) + @pytest.mark.parametrize("dtype", [None, "category"]) + @pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) + @pytest.mark.parametrize("na", [True, False]) + def test_startswith(self, dtype, null_value, na): + # add category dtype parametrizations for GH-36241 + values = Series( + ["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"], + dtype=dtype, + ) result = values.str.startswith("foo") exp = Series([False, np.nan, True, False, False, np.nan, True]) tm.assert_series_equal(result, exp) - result = values.str.startswith("foo", na=True) - tm.assert_series_equal(result, exp.fillna(True).astype(bool)) + result = values.str.startswith("foo", na=na) + exp = Series([False, na, True, False, False, na, True]) + tm.assert_series_equal(result, exp) # mixed mixed = np.array( @@ -867,15 +879,23 @@ def test_startswith(self): ) tm.assert_series_equal(rs, xp) - def test_endswith(self): - values = Series(["om", np.nan, "foo_nom", "nom", "bar_foo", np.nan, "foo"]) + @pytest.mark.parametrize("dtype", [None, "category"]) + @pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) + @pytest.mark.parametrize("na", [True, False]) + def test_endswith(self, dtype, null_value, na): + # add category dtype parametrizations for GH-36241 + values = Series( + ["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"], + dtype=dtype, + ) result = values.str.endswith("foo") exp = Series([False, np.nan, False, False, True, np.nan, True]) tm.assert_series_equal(result, exp) - result = values.str.endswith("foo", na=False) - tm.assert_series_equal(result, exp.fillna(False).astype(bool)) + result = values.str.endswith("foo", na=na) + exp = Series([False, na, False, False, True, na, True]) + tm.assert_series_equal(result, exp) # mixed mixed = np.array( @@ -3552,6 +3572,10 @@ def test_string_array(any_string_method): assert result.dtype == "boolean" result = result.astype(object) + elif expected.dtype == "bool": + assert result.dtype == "boolean" + result = result.astype("bool") + elif expected.dtype == "float" and expected.isna().any(): assert result.dtype == "Int64" result = result.astype("float")