File tree 3 files changed +41
-5
lines changed
tests/indexes/categorical
3 files changed +41
-5
lines changed Original file line number Diff line number Diff line change 37
37
from pandas .core .dtypes .dtypes import CategoricalDtype
38
38
from pandas .core .dtypes .generic import ABCIndexClass , ABCSeries
39
39
from pandas .core .dtypes .inference import is_hashable
40
- from pandas .core .dtypes .missing import isna , notna
40
+ from pandas .core .dtypes .missing import is_valid_nat_for_dtype , isna , notna
41
41
42
42
from pandas .core import ops
43
43
from pandas .core .accessor import PandasDelegate , delegate_names
@@ -1834,7 +1834,7 @@ def __contains__(self, key) -> bool:
1834
1834
Returns True if `key` is in this Categorical.
1835
1835
"""
1836
1836
# if key is a NaN, check if any NaN is in self.
1837
- if is_scalar (key ) and isna ( key ):
1837
+ if is_valid_nat_for_dtype (key , self . categories . dtype ):
1838
1838
return self .isna ().any ()
1839
1839
1840
1840
return contains (self , key , container = self ._codes )
Original file line number Diff line number Diff line change 19
19
is_scalar ,
20
20
)
21
21
from pandas .core .dtypes .dtypes import CategoricalDtype
22
- from pandas .core .dtypes .missing import isna
22
+ from pandas .core .dtypes .missing import is_valid_nat_for_dtype , isna
23
23
24
24
from pandas .core import accessor
25
25
from pandas .core .algorithms import take_1d
@@ -365,10 +365,9 @@ def _has_complex_internals(self) -> bool:
365
365
@doc (Index .__contains__ )
366
366
def __contains__ (self , key : Any ) -> bool :
367
367
# if key is a NaN, check if any NaN is in self.
368
- if is_scalar (key ) and isna ( key ):
368
+ if is_valid_nat_for_dtype (key , self . categories . dtype ):
369
369
return self .hasnans
370
370
371
- hash (key )
372
371
return contains (self , key , container = self ._engine )
373
372
374
373
@doc (Index .astype )
Original file line number Diff line number Diff line change @@ -285,6 +285,43 @@ def test_contains_nan(self):
285
285
ci = CategoricalIndex (list ("aabbca" ) + [np .nan ], categories = list ("cabdef" ))
286
286
assert np .nan in ci
287
287
288
+ @pytest .mark .parametrize ("unwrap" , [True , False ])
289
+ def test_contains_na_dtype (self , unwrap ):
290
+ dti = pd .date_range ("2016-01-01" , periods = 100 ).insert (0 , pd .NaT )
291
+ pi = dti .to_period ("D" )
292
+ tdi = dti - dti [- 1 ]
293
+ ci = CategoricalIndex (dti )
294
+
295
+ obj = ci
296
+ if unwrap :
297
+ obj = ci ._data
298
+
299
+ assert np .nan in obj
300
+ assert None in obj
301
+ assert pd .NaT in obj
302
+ assert np .datetime64 ("NaT" ) in obj
303
+ assert np .timedelta64 ("NaT" ) not in obj
304
+
305
+ obj2 = CategoricalIndex (tdi )
306
+ if unwrap :
307
+ obj2 = obj2 ._data
308
+
309
+ assert np .nan in obj2
310
+ assert None in obj2
311
+ assert pd .NaT in obj2
312
+ assert np .datetime64 ("NaT" ) not in obj2
313
+ assert np .timedelta64 ("NaT" ) in obj2
314
+
315
+ obj3 = CategoricalIndex (pi )
316
+ if unwrap :
317
+ obj3 = obj3 ._data
318
+
319
+ assert np .nan in obj3
320
+ assert None in obj3
321
+ assert pd .NaT in obj3
322
+ assert np .datetime64 ("NaT" ) not in obj3
323
+ assert np .timedelta64 ("NaT" ) not in obj3
324
+
288
325
@pytest .mark .parametrize (
289
326
"item, expected" ,
290
327
[
You can’t perform that action at this time.
0 commit comments