Skip to content

API: membership checks on ExtensionArray containing NA values #37867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ ExtensionArray
- Fixed bug when applying a NumPy ufunc with multiple outputs to an :class:`.IntegerArray` returning None (:issue:`36913`)
- Fixed an inconsistency in :class:`.PeriodArray`'s ``__init__`` signature to those of :class:`.DatetimeArray` and :class:`.TimedeltaArray` (:issue:`37289`)
- Reductions for :class:`.BooleanArray`, :class:`.Categorical`, :class:`.DatetimeArray`, :class:`.FloatingArray`, :class:`.IntegerArray`, :class:`.PeriodArray`, :class:`.TimedeltaArray`, and :class:`.PandasArray` are now keyword-only methods (:issue:`37541`)
- Fixed a bug where a ``ValueError`` was wrongly raised if a membership check was made on an ``ExtensionArray`` containing nan-like values (:issue:`37867`)

Other
^^^^^
Expand Down
14 changes: 14 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
is_array_like,
is_dtype_equal,
is_list_like,
is_scalar,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
Expand Down Expand Up @@ -354,6 +355,19 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def __contains__(self, item) -> bool:
"""
Return for `item in self`.
"""
# comparisons of any item to pd.NA always return pd.NA, so e.g. "a" in [pd.NA]
# would raise a TypeError. The implementation below works around that.
if item is self.dtype.na_value:
return isna(self).any() if self._can_hold_na else False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you use self.isna()

elif is_scalar(item) and isna(item):
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could maybe be done in a separate PR, as it is FloatingArray specific (and the NaN behaviour is not yet fully fleshed out), but this check will do the wrong thing for np.nan if the FloatingArray actually contains a NaN (as self.dtype._na_value is then pd.NA and not np.nan).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So a FloatingArray can have multiple nan types? Is that really necessary?

I think I would prefer that would be handled in FloatingArray, so I don't have to stretch this PR too far.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So a FloatingArray can have multiple nan types? Is that really necessary?

Not multiple "nan types", to be nitpicky, since pd.NA is not a "nan". But yes, it can contains both pd.NA and np.nan at the moment. But that's not yet fully decided though, long discussion at #32265. Happy to hear your thoughts about it.

I think I would prefer that would be handled in FloatingArray, so I don't have to stretch this PR too far.

Yes, will put it on my list of follow to-do items for FloatingArray that I need to open. So don't worry about it here

else:
return (item == self).any()

def __eq__(self, other: Any) -> ArrayLike:
"""
Return for `self == other` (element-wise equality).
Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/arrays/categorical/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,23 @@ def test_numeric_like_ops(self):
msg = "Object with dtype category cannot perform the numpy op log"
with pytest.raises(TypeError, match=msg):
np.log(s)

def test_contains(self, ordered):
# GH-37867
cat = Categorical(["a", "b"], ordered=ordered)
assert "a" in cat
assert "x" not in cat
assert np.nan not in cat
assert pd.NA not in cat

cat = Categorical([np.nan, "a"], ordered=ordered)
assert "a" in cat
assert "x" not in cat
assert np.nan in cat
assert pd.NA in cat

cat = cat[::-1]
assert "a" in cat
assert "x" not in cat
assert np.nan in cat
assert pd.NA in cat
16 changes: 16 additions & 0 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,19 @@ def test_to_numpy_na_value(dtype, nulls_fixture):
result = arr.to_numpy(na_value=na_value)
expected = np.array(["a", na_value, "b"], dtype=object)
tm.assert_numpy_array_equal(result, expected)


def test_contains():
# GH-37867
arr = pd.array(["a", "b"], dtype="string")

assert "a" in arr
assert "x" not in arr
assert np.nan not in arr
assert pd.NA not in arr

arr = pd.array(["a", pd.NA], dtype="string")
assert "a" in arr
assert "x" not in arr
assert np.nan not in arr
assert pd.NA in arr
37 changes: 37 additions & 0 deletions pandas/tests/extension/base/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,43 @@ def test_can_hold_na_valid(self, data):
# GH-20761
assert data._can_hold_na is True

def test_contains(self, data):
# GH-37867

data = data[~data.isna()]

scalar = data[0]

assert scalar in data
assert "124jhujbhjhb5" not in data

na_value = data.dtype.na_value

assert na_value not in data

for na_value_type in {None, np.nan, pd.NA, pd.NaT}:
assert na_value_type not in data

def test_contains_nan(self, data_missing):
# GH-37867
data = data_missing

scalar = data[~data.isna()][0]

assert scalar in data

na_value = data.dtype.na_value

if data.isna().any():
assert na_value in data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data is not guaranteed to have a missing values, but there is also a data_missing fixture that has this guarantee and which you can add here to the test, and use that here.
And then to be sure we test the case of data without NAs, maybe do the else with data[~data.isna()] (that is guaranteed to not be empty)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And thanks for this base test!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I've added a test method using data_missing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, you can pass both fixtures to the same test, then you can write the test with less duplication as it is now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can pass them as the string name of the test and use: frame = request.getfixturevalue(fixture_func_name) to create the actual frame

else:
assert na_value not in data

for na_value_type in {None, np.nan, pd.NA, pd.NaT}:
if na_value_type is na_value:
continue
assert na_value_type not in data

def test_memory_usage(self, data):
s = pd.Series(data)
result = s.memory_usage(index=False)
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def test_custom_asserts(self):
with pytest.raises(AssertionError, match=msg):
self.assert_frame_equal(a.to_frame(), b.to_frame())

@pytest.mark.xfail(reason="comparison method not implemented on JSONArray")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is than issue number for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue number is the first line in the method?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the xfail pls

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok.

def test_contains(self, data):
# GH-37867
super().test_contains(data)


class TestConstructors(BaseJSON, base.BaseConstructorsTests):
@pytest.mark.skip(reason="not implemented constructor from dtype")
Expand Down