Skip to content

Commit ebb727e

Browse files
authored
BUG: CategoricalIndex.__contains__ incorrect NaTs (pandas-dev#33947)
1 parent c557ab5 commit ebb727e

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

pandas/core/arrays/categorical.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@
3737
from pandas.core.dtypes.dtypes import CategoricalDtype
3838
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
3939
from pandas.core.dtypes.inference import is_hashable
40-
from pandas.core.dtypes.missing import isna, notna
40+
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna, notna
4141

4242
from pandas.core import ops
4343
from pandas.core.accessor import PandasDelegate, delegate_names
@@ -1834,7 +1834,7 @@ def __contains__(self, key) -> bool:
18341834
Returns True if `key` is in this Categorical.
18351835
"""
18361836
# if key is a NaN, check if any NaN is in self.
1837-
if is_scalar(key) and isna(key):
1837+
if is_valid_nat_for_dtype(key, self.categories.dtype):
18381838
return self.isna().any()
18391839

18401840
return contains(self, key, container=self._codes)

pandas/core/indexes/category.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
is_scalar,
2020
)
2121
from pandas.core.dtypes.dtypes import CategoricalDtype
22-
from pandas.core.dtypes.missing import isna
22+
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
2323

2424
from pandas.core import accessor
2525
from pandas.core.algorithms import take_1d
@@ -365,10 +365,9 @@ def _has_complex_internals(self) -> bool:
365365
@doc(Index.__contains__)
366366
def __contains__(self, key: Any) -> bool:
367367
# if key is a NaN, check if any NaN is in self.
368-
if is_scalar(key) and isna(key):
368+
if is_valid_nat_for_dtype(key, self.categories.dtype):
369369
return self.hasnans
370370

371-
hash(key)
372371
return contains(self, key, container=self._engine)
373372

374373
@doc(Index.astype)

pandas/tests/indexes/categorical/test_indexing.py

+37
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,43 @@ def test_contains_nan(self):
285285
ci = CategoricalIndex(list("aabbca") + [np.nan], categories=list("cabdef"))
286286
assert np.nan in ci
287287

288+
@pytest.mark.parametrize("unwrap", [True, False])
289+
def test_contains_na_dtype(self, unwrap):
290+
dti = pd.date_range("2016-01-01", periods=100).insert(0, pd.NaT)
291+
pi = dti.to_period("D")
292+
tdi = dti - dti[-1]
293+
ci = CategoricalIndex(dti)
294+
295+
obj = ci
296+
if unwrap:
297+
obj = ci._data
298+
299+
assert np.nan in obj
300+
assert None in obj
301+
assert pd.NaT in obj
302+
assert np.datetime64("NaT") in obj
303+
assert np.timedelta64("NaT") not in obj
304+
305+
obj2 = CategoricalIndex(tdi)
306+
if unwrap:
307+
obj2 = obj2._data
308+
309+
assert np.nan in obj2
310+
assert None in obj2
311+
assert pd.NaT in obj2
312+
assert np.datetime64("NaT") not in obj2
313+
assert np.timedelta64("NaT") in obj2
314+
315+
obj3 = CategoricalIndex(pi)
316+
if unwrap:
317+
obj3 = obj3._data
318+
319+
assert np.nan in obj3
320+
assert None in obj3
321+
assert pd.NaT in obj3
322+
assert np.datetime64("NaT") not in obj3
323+
assert np.timedelta64("NaT") not in obj3
324+
288325
@pytest.mark.parametrize(
289326
"item, expected",
290327
[

0 commit comments

Comments
 (0)