Skip to content

Commit 47d0da6

Browse files
authored
API: membership checks on ExtensionArray containing NA values (#37867)
1 parent 59710bc commit 47d0da6

File tree

7 files changed

+83
-0
lines changed

7 files changed

+83
-0
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ ExtensionArray
770770
- Fixed bug when applying a NumPy ufunc with multiple outputs to an :class:`.IntegerArray` returning None (:issue:`36913`)
771771
- Fixed an inconsistency in :class:`.PeriodArray`'s ``__init__`` signature to those of :class:`.DatetimeArray` and :class:`.TimedeltaArray` (:issue:`37289`)
772772
- Reductions for :class:`.BooleanArray`, :class:`.Categorical`, :class:`.DatetimeArray`, :class:`.FloatingArray`, :class:`.IntegerArray`, :class:`.PeriodArray`, :class:`.TimedeltaArray`, and :class:`.PandasArray` are now keyword-only methods (:issue:`37541`)
773+
- Fixed a bug where a ``TypeError`` was wrongly raised if a membership check was made on an ``ExtensionArray`` containing nan-like values (:issue:`37867`)
773774

774775
Other
775776
^^^^^

pandas/core/arrays/base.py

+18
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
is_array_like,
3838
is_dtype_equal,
3939
is_list_like,
40+
is_scalar,
4041
pandas_dtype,
4142
)
4243
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -354,6 +355,23 @@ def __iter__(self):
354355
for i in range(len(self)):
355356
yield self[i]
356357

358+
def __contains__(self, item) -> bool:
    """
    Return for `item in self`.

    Parameters
    ----------
    item : object
        The value whose membership is checked.

    Returns
    -------
    bool
        Always a plain Python ``bool``; the result of ``.any()`` is coerced so
        the return value matches the annotation.
    """
    # GH37867
    # comparisons of any item to pd.NA always return pd.NA, so e.g. "a" in [pd.NA]
    # would raise a TypeError. The implementation below works around that.
    if is_scalar(item) and isna(item):
        if not self._can_hold_na:
            return False
        # A nan-like only counts as contained when it is this dtype's own
        # na_value (or an instance of the dtype's scalar type) and the array
        # actually has missing entries.
        if item is self.dtype.na_value or isinstance(item, self.dtype.type):
            return bool(self.isna().any())
        return False
    return bool((item == self).any())
374+
357375
def __eq__(self, other: Any) -> ArrayLike:
358376
"""
359377
Return for `self == other` (element-wise equality).

pandas/tests/extension/arrow/test_bool.py

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def test_view(self, data):
5050
# __setitem__ does not work, so we only have a smoke-test
5151
data.view()
5252

53+
@pytest.mark.xfail(raises=AssertionError, reason="Not implemented yet")
def test_contains(self, data, data_missing, nulls_fixture):
    """Delegate to the parent membership test; an AssertionError is expected."""
    super().test_contains(data, data_missing, nulls_fixture)
56+
5357

5458
class TestConstructors(BaseArrowTests, base.BaseConstructorsTests):
5559
def test_from_dtype(self, data):

pandas/tests/extension/base/interface.py

+23
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,29 @@ def test_can_hold_na_valid(self, data):
2929
# GH-20761
3030
assert data._can_hold_na is True
3131

32+
def test_contains(self, data, data_missing, nulls_fixture):
33+
# GH-37867
34+
# Tests for membership checks. Membership checks for nan-likes is tricky and
35+
# the settled on rule is: `nan_like in arr` is True if nan_like is
36+
# arr.dtype.na_value and arr.isna().any() is True. Else the check returns False.
37+
38+
na_value = data.dtype.na_value
39+
# ensure data without missing values
40+
data = data[~data.isna()]
41+
42+
# first elements are non-missing
43+
assert data[0] in data
44+
assert data_missing[0] in data_missing
45+
46+
# check the presence of na_value
47+
assert na_value in data_missing
48+
assert na_value not in data
49+
50+
if nulls_fixture is not na_value:
51+
# the data can never contain other nan-likes than na_value
52+
assert nulls_fixture not in data
53+
assert nulls_fixture not in data_missing
54+
3255
def test_memory_usage(self, data):
3356
s = pd.Series(data)
3457
result = s.memory_usage(index=False)

pandas/tests/extension/decimal/array.py

+8
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ def __setitem__(self, key, value):
155155
def __len__(self) -> int:
    """Return the number of elements held in ``self._data``."""
    backing = self._data
    return len(backing)
157157

158+
def __contains__(self, item) -> bool:
    """
    Return for `item in self`.

    Anything that is not a ``decimal.Decimal`` is never contained. A NaN
    decimal is contained when ``self.isna()`` reports any missing value. Every
    other decimal is handed to the parent class implementation.
    """
    if not isinstance(item, decimal.Decimal):
        return False
    if item.is_nan():
        return self.isna().any()
    return super().__contains__(item)
165+
158166
@property
159167
def nbytes(self) -> int:
160168
n = len(self)

pandas/tests/extension/json/test_json.py

+7
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ def test_custom_asserts(self):
143143
with pytest.raises(AssertionError, match=msg):
144144
self.assert_frame_equal(a.to_frame(), b.to_frame())
145145

146+
@pytest.mark.xfail(
    reason="comparison method not implemented for JSONArray (GH-37867)"
)
def test_contains(self, data, data_missing, nulls_fixture):
    """
    Run the inherited membership test; expected to fail for JSONArray.

    The marker's stated reason is that the comparison method is not
    implemented for JSONArray.
    """
    # GH-37867
    # Forward every fixture the shared test_contains declares
    # (data, data_missing, nulls_fixture). Passing only `data` would raise a
    # TypeError in the call itself, and the unrestricted xfail marker would
    # hide it, so the test would "expectedly fail" for the wrong reason.
    super().test_contains(data, data_missing, nulls_fixture)
152+
146153

147154
class TestConstructors(BaseJSON, base.BaseConstructorsTests):
148155
@pytest.mark.skip(reason="not implemented constructor from dtype")

pandas/tests/extension/test_categorical.py

+22
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,28 @@ def test_memory_usage(self, data):
8787
# Is this deliberate?
8888
super().test_memory_usage(data)
8989

90+
def test_contains(self, data, data_missing, nulls_fixture):
91+
# GH-37867
92+
# na value handling in Categorical.__contains__ is deprecated.
93+
# See base.BaseInterFaceTests.test_contains for more details.
94+
95+
na_value = data.dtype.na_value
96+
# ensure data without missing values
97+
data = data[~data.isna()]
98+
99+
# first elements are non-missing
100+
assert data[0] in data
101+
assert data_missing[0] in data_missing
102+
103+
# check the presence of na_value
104+
assert na_value in data_missing
105+
assert na_value not in data
106+
107+
# Categoricals can contain other nan-likes than na_value
108+
if nulls_fixture is not na_value:
109+
assert nulls_fixture not in data
110+
assert nulls_fixture in data_missing # this line differs from super method
111+
90112

91113
class TestConstructors(base.BaseConstructorsTests):
    """Constructor tests taken as-is from base.BaseConstructorsTests."""

0 commit comments

Comments
 (0)