Skip to content

Commit 547388c

Browse files
committed
various
1 parent cf1791e commit 547388c

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

pandas/core/arrays/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from pandas.core.dtypes.dtypes import ExtensionDtype
4343
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
44-
from pandas.core.dtypes.missing import isna
44+
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
4545

4646
from pandas.core import ops
4747
from pandas.core.algorithms import factorize_array, unique
@@ -360,7 +360,7 @@ def __contains__(self, item) -> bool:
360360
"""
361361
# comparisons of any item to pd.NA always return pd.NA, so e.g. "a" in [pd.NA]
362362
# would raise a TypeError. The implementation below works around that.
363-
if isna(item):
363+
if is_valid_nat_for_dtype(item, self.dtype):
364364
return isna(self).any() if self._can_hold_na else False
365365
else:
366366
return (item == self).any()

pandas/tests/arrays/categorical/test_operators.py

+3
Original file line numberDiff line numberDiff line change
@@ -401,14 +401,17 @@ def test_contains(self, ordered):
401401
cat = Categorical(["a", "b"], ordered=ordered)
402402
assert "a" in cat
403403
assert "x" not in cat
404+
assert np.nan not in cat
404405
assert pd.NA not in cat
405406

406407
cat = Categorical([np.nan, "a"], ordered=ordered)
407408
assert "a" in cat
408409
assert "x" not in cat
410+
assert np.nan in cat
409411
assert pd.NA in cat
410412

411413
cat = cat[::-1]
412414
assert "a" in cat
413415
assert "x" not in cat
416+
assert np.nan in cat
414417
assert pd.NA in cat

pandas/tests/arrays/string_/test_string.py

+16
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,19 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
524524
result = arr.to_numpy(na_value=na_value)
525525
expected = np.array(["a", na_value, "b"], dtype=object)
526526
tm.assert_numpy_array_equal(result, expected)
527+
528+
529+
def test_contains():
530+
# GH-xxxxx
531+
arr = pd.array(np.array(["a", "b"], dtype="string"))
532+
533+
assert "a" in arr
534+
assert "x" not in arr
535+
assert np.nan not in arr
536+
assert pd.NA not in arr
537+
538+
arr = pd.arrays.StringArray(np.array(["a", pd.NA]))
539+
assert "a" in arr
540+
assert "x" not in arr
541+
assert np.nan in arr
542+
assert pd.NA in arr

0 commit comments

Comments
 (0)