diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 6f046d3a9379d..825c5367a8b37 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -752,6 +752,7 @@ ExtensionArray - Fixed bug when applying a NumPy ufunc with multiple outputs to an :class:`.IntegerArray` returning None (:issue:`36913`) - Fixed an inconsistency in :class:`.PeriodArray`'s ``__init__`` signature to those of :class:`.DatetimeArray` and :class:`.TimedeltaArray` (:issue:`37289`) - Reductions for :class:`.BooleanArray`, :class:`.Categorical`, :class:`.DatetimeArray`, :class:`.FloatingArray`, :class:`.IntegerArray`, :class:`.PeriodArray`, :class:`.TimedeltaArray`, and :class:`.PandasArray` are now keyword-only methods (:issue:`37541`) +- Fixed a bug where a ``TypeError`` was wrongly raised if a membership check was made on an ``ExtensionArray`` containing nan-like values (:issue:`37867`) Other ^^^^^ diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 448025e05422d..76b7877b0ac70 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -37,6 +37,7 @@ is_array_like, is_dtype_equal, is_list_like, + is_scalar, pandas_dtype, ) from pandas.core.dtypes.dtypes import ExtensionDtype @@ -354,6 +355,23 @@ def __iter__(self): for i in range(len(self)): yield self[i] + def __contains__(self, item) -> bool: + """ + Return for `item in self`. + """ + # GH37867 + # comparisons of any item to pd.NA always return pd.NA, so e.g. "a" in [pd.NA] + # would raise a TypeError. The implementation below works around that. + if is_scalar(item) and isna(item): + if not self._can_hold_na: + return False + elif item is self.dtype.na_value or isinstance(item, self.dtype.type): + return self.isna().any() + else: + return False + else: + return (item == self).any() + def __eq__(self, other: Any) -> ArrayLike: """ Return for `self == other` (element-wise equality). diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py index 12426a0c92c55..922b3b94c16c1 100644 --- a/pandas/tests/extension/arrow/test_bool.py +++ b/pandas/tests/extension/arrow/test_bool.py @@ -50,6 +50,10 @@ def test_view(self, data): # __setitem__ does not work, so we only have a smoke-test data.view() + @pytest.mark.xfail(raises=AssertionError, reason="Not implemented yet") + def test_contains(self, data, data_missing, nulls_fixture): + super().test_contains(data, data_missing, nulls_fixture) + class TestConstructors(BaseArrowTests, base.BaseConstructorsTests): def test_from_dtype(self, data): diff --git a/pandas/tests/extension/base/interface.py b/pandas/tests/extension/base/interface.py index 9ae4b01508d79..d7997310dde3d 100644 --- a/pandas/tests/extension/base/interface.py +++ b/pandas/tests/extension/base/interface.py @@ -29,6 +29,29 @@ def test_can_hold_na_valid(self, data): # GH-20761 assert data._can_hold_na is True + def test_contains(self, data, data_missing, nulls_fixture): + # GH-37867 + # Tests for membership checks. Membership checks for nan-likes is tricky and + # the settled on rule is: `nan_like in arr` is True if nan_like is + # arr.dtype.na_value and arr.isna().any() is True. Else the check returns False. + + na_value = data.dtype.na_value + # ensure data without missing values + data = data[~data.isna()] + + # first elements are non-missing + assert data[0] in data + assert data_missing[0] in data_missing + + # check the presence of na_value + assert na_value in data_missing + assert na_value not in data + + if nulls_fixture is not na_value: + # the data can never contain other nan-likes than na_value + assert nulls_fixture not in data + assert nulls_fixture not in data_missing + def test_memory_usage(self, data): s = pd.Series(data) result = s.memory_usage(index=False) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 9ede9c7fbd0fd..a713550dafa5c 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -155,6 +155,14 @@ def __setitem__(self, key, value): def __len__(self) -> int: return len(self._data) + def __contains__(self, item) -> bool: + if not isinstance(item, decimal.Decimal): + return False + elif item.is_nan(): + return self.isna().any() + else: + return super().__contains__(item) + @property def nbytes(self) -> int: n = len(self) diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index 74ca341e27bf8..3a5e49796c53b 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -143,6 +143,13 @@ def test_custom_asserts(self): with pytest.raises(AssertionError, match=msg): self.assert_frame_equal(a.to_frame(), b.to_frame()) + @pytest.mark.xfail( + reason="comparison method not implemented for JSONArray (GH-37867)" + ) + def test_contains(self, data): + # GH-37867 + super().test_contains(data) + class TestConstructors(BaseJSON, base.BaseConstructorsTests): @pytest.mark.skip(reason="not implemented constructor from dtype") diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index 95f338cbc3240..d03a9ab6b2588 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -87,6 +87,28 @@ def test_memory_usage(self, data): # Is this deliberate? super().test_memory_usage(data) + def test_contains(self, data, data_missing, nulls_fixture): + # GH-37867 + # na value handling in Categorical.__contains__ is deprecated. + # See base.BaseInterFaceTests.test_contains for more details. + + na_value = data.dtype.na_value + # ensure data without missing values + data = data[~data.isna()] + + # first elements are non-missing + assert data[0] in data + assert data_missing[0] in data_missing + + # check the presence of na_value + assert na_value in data_missing + assert na_value not in data + + # Categoricals can contain other nan-likes than na_value + if nulls_fixture is not na_value: + assert nulls_fixture not in data + assert nulls_fixture in data_missing # this line differs from super method + class TestConstructors(base.BaseConstructorsTests): pass