Skip to content

Commit f9c7ed5

Browse files
committed
stricter na_value requirements
1 parent e780f3e commit f9c7ed5

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

pandas/core/arrays/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
is_array_like,
3838
is_dtype_equal,
3939
is_list_like,
40+
is_scalar,
4041
pandas_dtype,
4142
)
4243
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -360,8 +361,10 @@ def __contains__(self, item) -> bool:
360361
"""
361362
# comparisons of any item to pd.NA always return pd.NA, so e.g. "a" in [pd.NA]
362363
# would raise a TypeError. The implementation below works around that.
363-
if is_valid_nat_for_dtype(item, self.dtype):
364+
if item is self.dtype.na_value:
364365
return isna(self).any() if self._can_hold_na else False
366+
elif is_scalar(item) and isna(item):
367+
return False
365368
else:
366369
return (item == self).any()
367370

pandas/tests/arrays/string_/test_string.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
527527

528528

529529
def test_contains():
530-
# GH-xxxxx
530+
# GH-37867
531531
arr = pd.array(["a", "b"], dtype="string")
532532

533533
assert "a" in arr
@@ -538,5 +538,5 @@ def test_contains():
538538
arr = pd.array(["a", pd.NA], dtype="string")
539539
assert "a" in arr
540540
assert "x" not in arr
541-
assert np.nan in arr
541+
assert np.nan not in arr
542542
assert pd.NA in arr

pandas/tests/extension/base/interface.py

+17
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ def test_can_hold_na_valid(self, data):
2929
# GH-20761
3030
assert data._can_hold_na is True
3131

32+
def test_contains(self, data):
33+
# GH-37867
34+
scalar = data[~data.isna()][0]
35+
36+
assert scalar in data
37+
38+
na_value = data.dtype.na_value
39+
other_na_value_types = {np.nan, pd.NA, pd.NaT}.difference({na_value})
40+
if data.isna().any():
41+
assert na_value in data
42+
for na_value_type in other_na_value_types:
43+
assert na_value_type not in data
44+
else:
45+
assert na_value not in data
46+
for na_value_type in other_na_value_types:
47+
assert na_value_type not in data
48+
3249
def test_memory_usage(self, data):
3350
s = pd.Series(data)
3451
result = s.memory_usage(index=False)

pandas/tests/extension/json/test_json.py

+5
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def test_custom_asserts(self):
143143
with pytest.raises(AssertionError, match=msg):
144144
self.assert_frame_equal(a.to_frame(), b.to_frame())
145145

146+
@pytest.mark.xfail(reason="comparison method not implemented on JSONArray")
147+
def test_contains(self, data):
148+
# GH-37867
149+
super().test_contains(data)
150+
146151

147152
class TestConstructors(BaseJSON, base.BaseConstructorsTests):
148153
@pytest.mark.skip(reason="not implemented constructor from dtype")

0 commit comments

Comments
 (0)