Skip to content

Commit 4072ed5

Browse files
pulkitmalooPingviinituutti
authored andcommitted
BUG: fixed .str.contains(..., na=False) for categorical series (pandas-dev#22170)
1 parent a5ad5fc commit 4072ed5

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

doc/source/whatsnew/v0.24.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ Strings
12801280

12811281
- Bug in :meth:`Index.str.partition` was not nan-safe (:issue:`23558`).
12821282
- Bug in :meth:`Index.str.split` was not nan-safe (:issue:`23677`).
1283-
-
1283+
- Bug :func:`Series.str.contains` not respecting the ``na`` argument for a ``Categorical`` dtype ``Series`` (:issue:`22158`)
12841284

12851285
Interval
12861286
^^^^^^^^

pandas/core/strings.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1857,7 +1857,7 @@ def __iter__(self):
18571857
g = self.get(i)
18581858

18591859
def _wrap_result(self, result, use_codes=True,
1860-
name=None, expand=None):
1860+
name=None, expand=None, fill_value=np.nan):
18611861

18621862
from pandas.core.index import Index, MultiIndex
18631863

@@ -1867,7 +1867,8 @@ def _wrap_result(self, result, use_codes=True,
18671867
# so make it possible to skip this step as the method already did this
18681868
# before the transformation...
18691869
if use_codes and self._is_categorical:
1870-
result = take_1d(result, self._orig.cat.codes)
1870+
result = take_1d(result, self._orig.cat.codes,
1871+
fill_value=fill_value)
18711872

18721873
if not hasattr(result, 'ndim') or not hasattr(result, 'dtype'):
18731874
return result
@@ -2520,12 +2521,12 @@ def join(self, sep):
25202521
def contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
25212522
result = str_contains(self._parent, pat, case=case, flags=flags, na=na,
25222523
regex=regex)
2523-
return self._wrap_result(result)
2524+
return self._wrap_result(result, fill_value=na)
25242525

25252526
@copy(str_match)
25262527
def match(self, pat, case=True, flags=0, na=np.nan):
25272528
result = str_match(self._parent, pat, case=case, flags=flags, na=na)
2528-
return self._wrap_result(result)
2529+
return self._wrap_result(result, fill_value=na)
25292530

25302531
@copy(str_replace)
25312532
def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):

pandas/tests/test_strings.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -512,10 +512,28 @@ def test_contains(self):
512512
assert result.dtype == np.bool_
513513
tm.assert_numpy_array_equal(result, expected)
514514

515-
# na
516-
values = Series(['om', 'foo', np.nan])
517-
res = values.str.contains('foo', na="foo")
518-
assert res.loc[2] == "foo"
515+
def test_contains_for_object_category(self):
516+
# gh 22158
517+
518+
# na for category
519+
values = Series(["a", "b", "c", "a", np.nan], dtype="category")
520+
result = values.str.contains('a', na=True)
521+
expected = Series([True, False, False, True, True])
522+
tm.assert_series_equal(result, expected)
523+
524+
result = values.str.contains('a', na=False)
525+
expected = Series([True, False, False, True, False])
526+
tm.assert_series_equal(result, expected)
527+
528+
# na for objects
529+
values = Series(["a", "b", "c", "a", np.nan])
530+
result = values.str.contains('a', na=True)
531+
expected = Series([True, False, False, True, True])
532+
tm.assert_series_equal(result, expected)
533+
534+
result = values.str.contains('a', na=False)
535+
expected = Series([True, False, False, True, False])
536+
tm.assert_series_equal(result, expected)
519537

520538
def test_startswith(self):
521539
values = Series(['om', NA, 'foo_nom', 'nom', 'bar_foo', NA, 'foo'])
@@ -2893,7 +2911,7 @@ def test_get_complex_nested(self, to_type):
28932911
expected = Series([np.nan])
28942912
tm.assert_series_equal(result, expected)
28952913

2896-
def test_more_contains(self):
2914+
def test_contains_moar(self):
28972915
# PR #1179
28982916
s = Series(['A', 'B', 'C', 'Aaba', 'Baca', '', NA,
28992917
'CABA', 'dog', 'cat'])
@@ -2943,7 +2961,7 @@ def test_contains_nan(self):
29432961
expected = Series([np.nan, np.nan, np.nan], dtype=np.object_)
29442962
assert_series_equal(result, expected)
29452963

2946-
def test_more_replace(self):
2964+
def test_replace_moar(self):
29472965
# PR #1179
29482966
s = Series(['A', 'B', 'C', 'Aaba', 'Baca', '', NA, 'CABA',
29492967
'dog', 'cat'])

0 commit comments

Comments
 (0)