Skip to content

Commit 2bd1744

Browse files
asishm and meeseeksmachine
authored and committed
Backport PR pandas-dev#36249: BUG: na parameter for str.startswith and str.endswith not propagating for Series with categorical dtype
1 parent f757f62 commit 2bd1744

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

doc/source/whatsnew/v1.1.3.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Fixed regressions
2222

2323
Bug fixes
2424
~~~~~~~~~
25-
-
25+
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with ``category`` dtype not propagating ``na`` parameter (:issue:`36241`)
2626

2727
.. ---------------------------------------------------------------------------
2828

pandas/core/strings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2050,7 +2050,7 @@ def wrapper2(self, pat, flags=0, **kwargs):
20502050
@forbid_nonstring_types(forbidden_types, name=name)
20512051
def wrapper3(self, pat, na=np.nan):
20522052
result = f(self._parent, pat, na=na)
2053-
return self._wrap_result(result, returns_string=returns_string)
2053+
return self._wrap_result(result, returns_string=returns_string, fill_value=na)
20542054

20552055
wrapper = wrapper3 if na else wrapper2 if flags else wrapper1
20562056

pandas/tests/test_strings.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def assert_series_or_index_equal(left, right):
2929
("decode", ("UTF-8",), {}),
3030
("encode", ("UTF-8",), {}),
3131
("endswith", ("a",), {}),
32+
("endswith", ("a",), {"na": True}),
33+
("endswith", ("a",), {"na": False}),
3234
("extract", ("([a-z]*)",), {"expand": False}),
3335
("extract", ("([a-z]*)",), {"expand": True}),
3436
("extractall", ("([a-z]*)",), {}),
@@ -58,6 +60,8 @@ def assert_series_or_index_equal(left, right):
5860
("split", (" ",), {"expand": False}),
5961
("split", (" ",), {"expand": True}),
6062
("startswith", ("a",), {}),
63+
("startswith", ("a",), {"na": True}),
64+
("startswith", ("a",), {"na": False}),
6165
# translating unicode points of "a" to "d"
6266
("translate", ({97: 100},), {}),
6367
("wrap", (2,), {}),
@@ -838,15 +842,23 @@ def test_contains_for_object_category(self):
838842
expected = Series([True, False, False, True, False])
839843
tm.assert_series_equal(result, expected)
840844

841-
def test_startswith(self):
842-
values = Series(["om", np.nan, "foo_nom", "nom", "bar_foo", np.nan, "foo"])
845+
@pytest.mark.parametrize("dtype", [None, "category"])
846+
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
847+
@pytest.mark.parametrize("na", [True, False])
848+
def test_startswith(self, dtype, null_value, na):
849+
# add category dtype parametrizations for GH-36241
850+
values = Series(
851+
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
852+
dtype=dtype,
853+
)
843854

844855
result = values.str.startswith("foo")
845856
exp = Series([False, np.nan, True, False, False, np.nan, True])
846857
tm.assert_series_equal(result, exp)
847858

848-
result = values.str.startswith("foo", na=True)
849-
tm.assert_series_equal(result, exp.fillna(True).astype(bool))
859+
result = values.str.startswith("foo", na=na)
860+
exp = Series([False, na, True, False, False, na, True])
861+
tm.assert_series_equal(result, exp)
850862

851863
# mixed
852864
mixed = np.array(
@@ -867,15 +879,23 @@ def test_startswith(self):
867879
)
868880
tm.assert_series_equal(rs, xp)
869881

870-
def test_endswith(self):
871-
values = Series(["om", np.nan, "foo_nom", "nom", "bar_foo", np.nan, "foo"])
882+
@pytest.mark.parametrize("dtype", [None, "category"])
883+
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
884+
@pytest.mark.parametrize("na", [True, False])
885+
def test_endswith(self, dtype, null_value, na):
886+
# add category dtype parametrizations for GH-36241
887+
values = Series(
888+
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
889+
dtype=dtype,
890+
)
872891

873892
result = values.str.endswith("foo")
874893
exp = Series([False, np.nan, False, False, True, np.nan, True])
875894
tm.assert_series_equal(result, exp)
876895

877-
result = values.str.endswith("foo", na=False)
878-
tm.assert_series_equal(result, exp.fillna(False).astype(bool))
896+
result = values.str.endswith("foo", na=na)
897+
exp = Series([False, na, False, False, True, na, True])
898+
tm.assert_series_equal(result, exp)
879899

880900
# mixed
881901
mixed = np.array(
@@ -3552,6 +3572,10 @@ def test_string_array(any_string_method):
35523572
assert result.dtype == "boolean"
35533573
result = result.astype(object)
35543574

3575+
elif expected.dtype == "bool":
3576+
assert result.dtype == "boolean"
3577+
result = result.astype("bool")
3578+
35553579
elif expected.dtype == "float" and expected.isna().any():
35563580
assert result.dtype == "Int64"
35573581
result = result.astype("float")

0 commit comments

Comments
 (0)