ENH: .equals for Extension Arrays #30652

Merged: 19 commits, May 9, 2020
Changes from 16 commits
36c8b88
ENH: .equals for Extension Arrays
dwhu Jan 3, 2020
786963c
ENH: Updating eq and ne methods for extension arrays.
dwhu Jan 3, 2020
6800315
Removing interval.py's __eq__ implementation due to conflict with @js…
dwhu Jan 3, 2020
a3e7b7f
ENH: Making EA eq and ne typed as Any.
dwhu Jan 3, 2020
860013f
ENH: Adding default implementation to ExtensionArray equals() and tests.
dwhu Jan 3, 2020
fc3d2c2
Merge remote-tracking branch 'upstream/master' into gh-27081
dwhu Jan 9, 2020
3da5726
Merge remote-tracking branch 'upstream/master' into gh-27081
jorisvandenbossche May 1, 2020
c5027dd
correct __eq/ne__ to be element-wise
jorisvandenbossche May 1, 2020
375664c
fix equals implementation (& instead of ==)
jorisvandenbossche May 1, 2020
b6ad2fb
base tests
jorisvandenbossche May 1, 2020
365362a
ensure to dispatch Series.equals to EA.equals
jorisvandenbossche May 1, 2020
aae2f94
Merge remote-tracking branch 'upstream/master' into gh-27081
jorisvandenbossche May 2, 2020
8d052ad
feedback: docs, whatsnew, dataframe test, strict dtype test
jorisvandenbossche May 2, 2020
9ee034e
add to reference docs
jorisvandenbossche May 2, 2020
38501e6
remove IntervalArray.__ne__
jorisvandenbossche May 2, 2020
dccec7f
type ignore following mypy issue (mypy/2783)
jorisvandenbossche May 5, 2020
0b1255f
Merge remote-tracking branch 'upstream/master' into gh-27081
jorisvandenbossche May 7, 2020
b8be858
try again without type: ignore
jorisvandenbossche May 7, 2020
4c7273f
updates
jorisvandenbossche May 8, 2020
1 change: 1 addition & 0 deletions doc/source/reference/extensions.rst
@@ -45,6 +45,7 @@ objects.
api.extensions.ExtensionArray.copy
api.extensions.ExtensionArray.view
api.extensions.ExtensionArray.dropna
api.extensions.ExtensionArray.equals
api.extensions.ExtensionArray.factorize
api.extensions.ExtensionArray.fillna
api.extensions.ExtensionArray.isna
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
@@ -150,6 +150,7 @@ Other enhancements
such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`)
- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
- The ``ExtensionArray`` class has now an ``equals`` method, similarly to ``Series.equals()`` (:issue:`27081`).
Contributor:
you can add the doc reference here

-

.. ---------------------------------------------------------------------------
4 changes: 3 additions & 1 deletion pandas/_testing.py
@@ -1490,7 +1490,9 @@ def box_expected(expected, box_cls, transpose=True):
-------
subclass of box_cls
"""
- if box_cls is pd.Index:
+ if box_cls is pd.array:
+ expected = pd.array(expected)
+ elif box_cls is pd.Index:
expected = pd.Index(expected)
elif box_cls is pd.Series:
expected = pd.Series(expected)
55 changes: 52 additions & 3 deletions pandas/core/arrays/base.py
@@ -58,6 +58,7 @@ class ExtensionArray:
dropna
factorize
fillna
equals
isna
ravel
repeat
@@ -84,6 +85,7 @@ class ExtensionArray:
* _from_factorized
* __getitem__
* __len__
* __eq__
* dtype
* nbytes
* isna
@@ -333,6 +335,24 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def __eq__(self, other: Any) -> ArrayLike: # type: ignore[override] # NOQA
Member:
I'm not familiar with the `ignore[override] # NOQA` pattern. Do we use that elsewhere?

Member:
I thought I posted the mypy issue, but apparently forgot.

So the problem here is that mypy expects a "bool" return for __eq__ since the base object is typed that way in the stubs. In python/mypy#2783, the recommended way to solve this is the above with # type: ignore[override].

But flake8 doesn't like that, hence also the #NOQA (PyCQA/pyflakes#475).

This was already done in another PR as well:

def sort_values( # type: ignore[override] # NOQA # issue 27237
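
To make the mypy point above concrete, here is a minimal sketch using a hypothetical `ElementwiseList` class (an illustration only, not pandas code). Because `object.__eq__` is annotated as returning `bool`, an `__eq__` that returns an element-wise result is reported by mypy as an incompatible override unless it carries the `# type: ignore[override]` comment.

```python
from typing import Any, List


class ElementwiseList:
    """Hypothetical container whose == compares element by element."""

    def __init__(self, values: List[Any]) -> None:
        self.values = list(values)

    # object.__eq__ is annotated as returning bool, so mypy flags this
    # signature as an incompatible override unless the error is silenced.
    def __eq__(self, other: Any) -> List[bool]:  # type: ignore[override]
        if not isinstance(other, ElementwiseList):
            return NotImplemented
        return [a == b for a, b in zip(self.values, other.values)]


result = ElementwiseList([1, 2, 3]) == ElementwiseList([1, 0, 3])
print(result)  # [True, False, True]
```

The extra `# NOQA` used in the PR is only needed where flake8 also objects to the comment; it is left out of this sketch.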

"""
Return for `self == other` (element-wise equality).
Contributor:
I think we should have some guidance here that if other is a Series then you should return NotImplemented.

Member:
could use ops.common.unpack_zerodim_and_defer

Member:
> could use ops.common.unpack_zerodim_and_defer

This method is to be implemented by EA authors, so those can't use that helper (unless we expose somewhere a public version of this).

(we could of course use that for our own EAs, but this PR is not changing any existing __eq__ implementation at the moment)

Member:
> I think we should have some guidance here that if other is a Series then you should return NotImplemented.

Added a comment about that
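
As a rough illustration of the guidance discussed in this thread, the sketch below shows an element-wise `__eq__` that returns `NotImplemented` for Series, Index, and DataFrame. `ToyArray` is a hypothetical stand-in wrapping a NumPy array; it is not a real `ExtensionArray` subclass and omits the rest of the interface.

```python
import numpy as np
import pandas as pd


class ToyArray:
    """Hypothetical stand-in for an ExtensionArray subclass (illustration only)."""

    def __init__(self, values):
        self._data = np.asarray(values)

    def __eq__(self, other):
        # Defer to pandas containers so they can unpack their own arrays
        # and dispatch the comparison back to the underlying array.
        if isinstance(other, (pd.Series, pd.Index, pd.DataFrame)):
            return NotImplemented
        if isinstance(other, ToyArray):
            other = other._data
        # Element-wise result: a boolean numpy ndarray.
        return self._data == other

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return ~result


arr = ToyArray([1, 2, 3])
deferred = arr.__eq__(pd.Series([1, 2, 3]))  # NotImplemented
elementwise = arr == ToyArray([1, 0, 3])     # boolean ndarray
print(deferred, elementwise.tolist())
```

Note that the `__ne__` here checks for `NotImplemented` explicitly, which is a choice made for this sketch; the default added in the PR is simply `~(self == other)`.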

"""
# Implementer note: this should return a boolean numpy ndarray or
# a boolean ExtensionArray.
# When `other` is one of Series, Index, or DataFrame, this method should
# return NotImplemented (to ensure that those objects are responsible for
# first unpacking the arrays, and then dispatch the operation to the
# underlying arrays)
raise AbstractMethodError(self)

def __ne__(self, other: Any) -> ArrayLike: # type: ignore[override] # NOQA
"""
Return for `self != other` (element-wise in-equality).
"""
return ~(self == other)

def to_numpy(
self, dtype=None, copy: bool = False, na_value=lib.no_default
) -> np.ndarray:
@@ -682,6 +702,35 @@ def searchsorted(self, value, side="left", sorter=None):
arr = self.astype(object)
return arr.searchsorted(value, side=side, sorter=sorter)

def equals(self, other: "ExtensionArray") -> bool:
"""
Return if another array is equivalent to this array.

Parameters
----------
other: ExtensionArray
Array to compare to this Array.

Returns
-------
boolean
Whether the arrays are equivalent.

"""
if not type(self) == type(other):
return False
Comment on lines +723 to +724
Contributor:
We should verify that this is the behavior we want. Namely

  1. other array-likes are not equivalent, even if they are all equal.
  2. subclasses are not equivalent, even if they are all equal.

The first seems fine. Not sure about the second.

Member:
I was planning to open an issue about this after this PR (and keep it strict here), because this is right now a bit inconsistent within pandas, and might require a more general discussion / clean-up (eg Series.equals is more strict (requires same dtype) than Index.equals ...)
But we can certainly also have the discussion here.

Contributor:
Happy to be strict for now.

Member:
I made it even more strict (same dtype, not just same class), and added a test for that.
Will open an issue for the general discussion.

Member:
I opened an issue for this at #33940

In the end, it seems to mainly come down to whether the dtype should be exactly equal or not.

Since for EAs the dtype is currently tied to the array class, requiring an equal dtype also implies the same class for now (with no additional check to allow subclasses).
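
The strict behavior settled on in this thread can be checked with a short example. This is a sketch that assumes a pandas version that includes this change (1.1 or later); the variable names are illustrative only.

```python
import pandas as pd

a64 = pd.array([1, 2, None], dtype="Int64")
same = pd.array([1, 2, None], dtype="Int64")
a32 = pd.array([1, 2, None], dtype="Int32")

# Same dtype, same values, NA in the same position.
same_dtype = a64.equals(same)

# Same array class and values, but a different dtype.
other_dtype = a64.equals(a32)

print(same_dtype, other_dtype)
```

Under the strict rule, the first comparison should report the arrays as equivalent and the second should not, matching the `Int64` versus `Int32` test added in this PR.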

elif not self.dtype == other.dtype:
return False
elif not len(self) == len(other):
return False
else:
equal_values = self == other
if isinstance(equal_values, ExtensionArray):
# boolean array with NA -> fill with False
equal_values = equal_values.fillna(False)
equal_na = self.isna() & other.isna()
return (equal_values | equal_na).all().item()
Member:
is the .item() necessary?

Member:
> is the .item() necessary?

It's to have a Python bool, instead of a numpy bool, as the result.
I added a comment to the tests to make it explicit that this is the reason we are asserting with `is True/False` (so that later on we don't inadvertently "clean" that up).

Member:
Thanks for clarifying. I was recently reminded that np.array(np.timedelta64(1234, "ns")).item() gives an int instead of timedelta64, so I'm now cautious around .item()
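
Both points in this thread can be reproduced with plain NumPy. The sketch below uses made-up boolean masks in place of the real `equal_values` and `equal_na` arrays computed inside `equals`.

```python
import numpy as np

equal_values = np.array([True, False, True])
equal_na = np.array([False, True, False])  # positions where both sides are NA

combined = (equal_values | equal_na).all()
as_python = combined.item()

print(isinstance(combined, bool))  # False: a numpy boolean scalar
print(as_python is True)           # True: a real Python bool

# .item() converts to the closest Python type, which is not always the
# same kind of object: nanosecond timedeltas come back as plain ints.
td = np.array(np.timedelta64(1234, "ns")).item()
print(isinstance(td, int), td)     # True 1234
```

This is why the tests assert with `is True` and `is False`: a NumPy boolean would be truthy but would fail an identity check against the Python singletons.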


def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
"""
Return an array and missing value suitable for factorization.
@@ -1129,7 +1178,7 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin):
"""

@classmethod
- def _create_method(cls, op, coerce_to_dtype=True):
+ def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
"""
A class method that returns a method that will correspond to an
operator for an ExtensionArray subclass, by dispatching to the
@@ -1197,7 +1246,7 @@ def _maybe_convert(arr):
# exception raised in _from_sequence; ensure we have ndarray
res = np.asarray(arr)
else:
- res = np.asarray(arr)
+ res = np.asarray(arr, dtype=result_dtype)
return res

if op.__name__ in {"divmod", "rdivmod"}:
@@ -1215,4 +1264,4 @@ def _create_arithmetic_method(cls, op):

@classmethod
def _create_comparison_method(cls, op):
- return cls._create_method(op, coerce_to_dtype=False)
+ return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool)
3 changes: 0 additions & 3 deletions pandas/core/arrays/interval.py
@@ -606,9 +606,6 @@ def __eq__(self, other):

return result

- def __ne__(self, other):
- return ~self.__eq__(other)

def fillna(self, value=None, method=None, limit=None):
"""
Fill NA/NaN values using the specified method.
5 changes: 5 additions & 0 deletions pandas/core/internals/blocks.py
@@ -1861,6 +1861,11 @@ def where(

return [self.make_block_same_class(result, placement=self.mgr_locs)]

def equals(self, other) -> bool:
if self.dtype != other.dtype or self.shape != other.shape:
Contributor:
why do you need this check here? (its already done on the .values no?)

Member:
Hmm, good question. I just copied that from the base Block.equals method (so it's done there as well). There, array_equivalent is used. Maybe that is less strict and the extra check is needed, which is not needed here.

Member:
At least from the docstring of array_equivalent (not sure how up to date that is), using that method indeed requires a more strict check in advance of calling it:

in corresponding locations. False otherwise. It is assumed that left and
right are NumPy arrays of the same dtype. The behavior of this function
(particularly with respect to NaNs) is not defined if the dtypes are
different.

So for EA.equals that extra check is not needed, and I will remove it here.
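
The need for a dtype check ahead of a value-only comparison can be illustrated with NumPy's public `np.array_equal`. It is used here only as a stand-in: it is not pandas' internal `array_equivalent`, whose exact behavior is not reproduced in this sketch.

```python
import numpy as np

ints = np.array([1, 2, 3])
floats = np.array([1.0, 2.0, 3.0])

# A value-only comparison ignores the dtype difference.
values_only = np.array_equal(ints, floats)

# Checking the dtype first gives the stricter answer.
with_dtype_check = ints.dtype == floats.dtype and np.array_equal(ints, floats)

print(values_only, with_dtype_check)  # True False
```

Since `ExtensionArray.equals` already compares dtypes itself, repeating that check in the block-level `equals` is redundant, which is the conclusion reached above.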

return False
return self.values.equals(other.values)

def _unstack(self, unstacker, fill_value, new_placement):
# ExtensionArray-safe unstack.
# We override ObjectBlock._unstack, which unstacks directly on the
10 changes: 10 additions & 0 deletions pandas/tests/arrays/integer/test_comparison.py
@@ -104,3 +104,13 @@ def test_compare_to_int(self, any_nullable_int_dtype, all_compare_operators):
expected[s2.isna()] = pd.NA

self.assert_series_equal(result, expected)


def test_equals():
# GH-30652
# equals is generally tested in /tests/extension/base/methods, but this
# specifically tests that two arrays of the same class but different dtype
# do not evaluate equal
a1 = pd.array([1, 2, None], dtype="Int64")
a2 = pd.array([1, 2, None], dtype="Int32")
assert a1.equals(a2) is False
29 changes: 29 additions & 0 deletions pandas/tests/extension/base/methods.py
@@ -421,3 +421,32 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
np.repeat(data, repeats, **kwargs)
else:
data.repeat(repeats, **kwargs)

@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)
data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype)

data = tm.box_expected(data, box, transpose=False)
data2 = tm.box_expected(data2, box, transpose=False)
data_na = tm.box_expected(data_na, box, transpose=False)

# we are asserting with `is True/False` explicitly, to test that the
# result is an actual Python bool, and not something "truthy"

assert data.equals(data) is True
assert data.equals(data.copy()) is True

# unequal other data
assert data.equals(data2) is False
assert data.equals(data_na) is False

# different length
assert data[:2].equals(data[:3]) is False

# empty are equal
assert data[:0].equals(data[:0]) is True

# other types
assert data.equals(None) is False
assert data[[0]].equals(data[0]) is False
8 changes: 6 additions & 2 deletions pandas/tests/extension/base/ops.py
@@ -139,10 +139,8 @@ class BaseComparisonOpsTests(BaseOpsUtil):
def _compare_other(self, s, data, op_name, other):
op = self.get_op_from_name(op_name)
if op_name == "__eq__":
- assert getattr(data, op_name)(other) is NotImplemented
assert not op(s, other).all()
elif op_name == "__ne__":
- assert getattr(data, op_name)(other) is NotImplemented
assert op(s, other).all()

else:
@@ -176,6 +174,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data):
else:
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")

if hasattr(data, "__ne__"):
result = data.__ne__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")


class BaseUnaryOpsTests(BaseOpsUtil):
def test_invert(self, data):
6 changes: 6 additions & 0 deletions pandas/tests/extension/json/array.py
@@ -105,6 +105,12 @@ def __setitem__(self, key, value):
def __len__(self) -> int:
return len(self.data)

def __eq__(self, other):
return NotImplemented

def __ne__(self, other):
return NotImplemented

def __array__(self, dtype=None):
if dtype is None:
dtype = object
4 changes: 4 additions & 0 deletions pandas/tests/extension/json/test_json.py
@@ -247,6 +247,10 @@ def test_where_series(self, data, na_value):
def test_searchsorted(self, data_for_sorting):
super().test_searchsorted(data_for_sorting)

@pytest.mark.skip(reason="Can't compare dicts.")
def test_equals(self, data, na_value, as_series):
pass


class TestCasting(BaseJSON, base.BaseCastingTests):
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
6 changes: 6 additions & 0 deletions pandas/tests/extension/test_numpy.py
@@ -261,6 +261,12 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
def test_diff(self, data, periods):
return super().test_diff(data, periods)

@skip_nested
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
# Fails creating with _from_sequence
super().test_equals(data, na_value, as_series, box)


@skip_nested
class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
5 changes: 5 additions & 0 deletions pandas/tests/extension/test_sparse.py
@@ -316,6 +316,11 @@ def test_shift_0_periods(self, data):
data._sparse_values[0] = data._sparse_values[1]
assert result._sparse_values[0] != result._sparse_values[1]

@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
def test_equals(self, data, na_value, as_series, box):
self._check_unsupported(data)
super().test_equals(data, na_value, as_series, box)


class TestCasting(BaseSparseTests, base.BaseCastingTests):
def test_astype_object_series(self, all_data):