Skip to content

Commit f21bc99

Browse files
ENH: .equals for Extension Arrays (#30652)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 6388370 commit f21bc99

File tree

13 files changed

+130
-9
lines changed

13 files changed

+130
-9
lines changed

doc/source/reference/extensions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ objects.
4545
api.extensions.ExtensionArray.copy
4646
api.extensions.ExtensionArray.view
4747
api.extensions.ExtensionArray.dropna
48+
api.extensions.ExtensionArray.equals
4849
api.extensions.ExtensionArray.factorize
4950
api.extensions.ExtensionArray.fillna
5051
api.extensions.ExtensionArray.isna

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ Other enhancements
150150
such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`)
151151
- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` have gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
152152
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
153+
- The ``ExtensionArray`` class now has an :meth:`~pandas.api.extensions.ExtensionArray.equals`
154+
method, similar to :meth:`Series.equals` (:issue:`27081`).
153155
-
154156

155157
.. ---------------------------------------------------------------------------

pandas/_testing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1490,7 +1490,9 @@ def box_expected(expected, box_cls, transpose=True):
14901490
-------
14911491
subclass of box_cls
14921492
"""
1493-
if box_cls is pd.Index:
1493+
if box_cls is pd.array:
1494+
expected = pd.array(expected)
1495+
elif box_cls is pd.Index:
14941496
expected = pd.Index(expected)
14951497
elif box_cls is pd.Series:
14961498
expected = pd.Series(expected)

pandas/core/arrays/base.py

+55-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ExtensionArray:
5858
dropna
5959
factorize
6060
fillna
61+
equals
6162
isna
6263
ravel
6364
repeat
@@ -84,6 +85,7 @@ class ExtensionArray:
8485
* _from_factorized
8586
* __getitem__
8687
* __len__
88+
* __eq__
8789
* dtype
8890
* nbytes
8991
* isna
@@ -333,6 +335,24 @@ def __iter__(self):
333335
for i in range(len(self)):
334336
yield self[i]
335337

338+
def __eq__(self, other: Any) -> ArrayLike:
339+
"""
340+
Return the result of `self == other` (element-wise equality).
341+
"""
342+
# Implementer note: this should return a boolean numpy ndarray or
343+
# a boolean ExtensionArray.
344+
# When `other` is one of Series, Index, or DataFrame, this method should
345+
# return NotImplemented (to ensure that those objects are responsible for
346+
# first unpacking the arrays, and then dispatch the operation to the
347+
# underlying arrays)
348+
raise AbstractMethodError(self)
349+
350+
def __ne__(self, other: Any) -> ArrayLike:
351+
"""
352+
Return the result of `self != other` (element-wise inequality).
353+
"""
354+
return ~(self == other)
355+
336356
def to_numpy(
337357
self, dtype=None, copy: bool = False, na_value=lib.no_default
338358
) -> np.ndarray:
@@ -682,6 +702,38 @@ def searchsorted(self, value, side="left", sorter=None):
682702
arr = self.astype(object)
683703
return arr.searchsorted(value, side=side, sorter=sorter)
684704

705+
def equals(self, other: "ExtensionArray") -> bool:
706+
"""
707+
Return whether another array is equivalent to this array.
708+
709+
Equivalent means that both arrays have the same shape and dtype, and
710+
all values compare equal. Missing values in the same location are
711+
considered equal (in contrast with normal equality).
712+
713+
Parameters
714+
----------
715+
other : ExtensionArray
716+
Array to compare to this Array.
717+
718+
Returns
719+
-------
720+
bool
721+
Whether the arrays are equivalent.
722+
"""
723+
if not type(self) == type(other):
724+
return False
725+
elif not self.dtype == other.dtype:
726+
return False
727+
elif not len(self) == len(other):
728+
return False
729+
else:
730+
equal_values = self == other
731+
if isinstance(equal_values, ExtensionArray):
732+
# boolean array with NA -> fill with False
733+
equal_values = equal_values.fillna(False)
734+
equal_na = self.isna() & other.isna()
735+
return (equal_values | equal_na).all().item()
736+
685737
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
686738
"""
687739
Return an array and missing value suitable for factorization.
@@ -1134,7 +1186,7 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin):
11341186
"""
11351187

11361188
@classmethod
1137-
def _create_method(cls, op, coerce_to_dtype=True):
1189+
def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):
11381190
"""
11391191
A class method that returns a method that will correspond to an
11401192
operator for an ExtensionArray subclass, by dispatching to the
@@ -1202,7 +1254,7 @@ def _maybe_convert(arr):
12021254
# exception raised in _from_sequence; ensure we have ndarray
12031255
res = np.asarray(arr)
12041256
else:
1205-
res = np.asarray(arr)
1257+
res = np.asarray(arr, dtype=result_dtype)
12061258
return res
12071259

12081260
if op.__name__ in {"divmod", "rdivmod"}:
@@ -1220,4 +1272,4 @@ def _create_arithmetic_method(cls, op):
12201272

12211273
@classmethod
12221274
def _create_comparison_method(cls, op):
1223-
return cls._create_method(op, coerce_to_dtype=False)
1275+
return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool)

pandas/core/arrays/interval.py

-3
Original file line numberDiff line numberDiff line change
@@ -606,9 +606,6 @@ def __eq__(self, other):
606606

607607
return result
608608

609-
def __ne__(self, other):
610-
return ~self.__eq__(other)
611-
612609
def fillna(self, value=None, method=None, limit=None):
613610
"""
614611
Fill NA/NaN values using the specified method.

pandas/core/internals/blocks.py

+3
Original file line numberDiff line numberDiff line change
@@ -1864,6 +1864,9 @@ def where(
18641864

18651865
return [self.make_block_same_class(result, placement=self.mgr_locs)]
18661866

1867+
def equals(self, other) -> bool:
1868+
return self.values.equals(other.values)
1869+
18671870
def _unstack(self, unstacker, fill_value, new_placement):
18681871
# ExtensionArray-safe unstack.
18691872
# We override ObjectBlock._unstack, which unstacks directly on the

pandas/tests/arrays/integer/test_comparison.py

+10
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,13 @@ def test_compare_to_int(self, any_nullable_int_dtype, all_compare_operators):
104104
expected[s2.isna()] = pd.NA
105105

106106
self.assert_series_equal(result, expected)
107+
108+
109+
def test_equals():
110+
# GH-30652
111+
# equals is generally tested in /tests/extension/base/methods, but this
112+
# specifically tests that two arrays of the same class but different dtype
113+
# do not evaluate equal
114+
a1 = pd.array([1, 2, None], dtype="Int64")
115+
a2 = pd.array([1, 2, None], dtype="Int32")
116+
assert a1.equals(a2) is False

pandas/tests/extension/base/methods.py

+29
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,32 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
421421
np.repeat(data, repeats, **kwargs)
422422
else:
423423
data.repeat(repeats, **kwargs)
424+
425+
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
426+
def test_equals(self, data, na_value, as_series, box):
427+
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)
428+
data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype)
429+
430+
data = tm.box_expected(data, box, transpose=False)
431+
data2 = tm.box_expected(data2, box, transpose=False)
432+
data_na = tm.box_expected(data_na, box, transpose=False)
433+
434+
# we are asserting with `is True/False` explicitly, to test that the
435+
# result is an actual Python bool, and not something "truthy"
436+
437+
assert data.equals(data) is True
438+
assert data.equals(data.copy()) is True
439+
440+
# unequal other data
441+
assert data.equals(data2) is False
442+
assert data.equals(data_na) is False
443+
444+
# different length
445+
assert data[:2].equals(data[:3]) is False
446+
447+
# empty arrays are equal
448+
assert data[:0].equals(data[:0]) is True
449+
450+
# other types
451+
assert data.equals(None) is False
452+
assert data[[0]].equals(data[0]) is False

pandas/tests/extension/base/ops.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,8 @@ class BaseComparisonOpsTests(BaseOpsUtil):
139139
def _compare_other(self, s, data, op_name, other):
140140
op = self.get_op_from_name(op_name)
141141
if op_name == "__eq__":
142-
assert getattr(data, op_name)(other) is NotImplemented
143142
assert not op(s, other).all()
144143
elif op_name == "__ne__":
145-
assert getattr(data, op_name)(other) is NotImplemented
146144
assert op(s, other).all()
147145

148146
else:
@@ -176,6 +174,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data):
176174
else:
177175
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")
178176

177+
if hasattr(data, "__ne__"):
178+
result = data.__ne__(other)
179+
assert result is NotImplemented
180+
else:
181+
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")
182+
179183

180184
class BaseUnaryOpsTests(BaseOpsUtil):
181185
def test_invert(self, data):

pandas/tests/extension/json/array.py

+6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ def __setitem__(self, key, value):
105105
def __len__(self) -> int:
106106
return len(self.data)
107107

108+
def __eq__(self, other):
109+
return NotImplemented
110+
111+
def __ne__(self, other):
112+
return NotImplemented
113+
108114
def __array__(self, dtype=None):
109115
if dtype is None:
110116
dtype = object

pandas/tests/extension/json/test_json.py

+4
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ def test_where_series(self, data, na_value):
262262
def test_searchsorted(self, data_for_sorting):
263263
super().test_searchsorted(data_for_sorting)
264264

265+
@pytest.mark.skip(reason="Can't compare dicts.")
266+
def test_equals(self, data, na_value, as_series):
267+
pass
268+
265269

266270
class TestCasting(BaseJSON, base.BaseCastingTests):
267271
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")

pandas/tests/extension/test_numpy.py

+6
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,12 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
276276
def test_diff(self, data, periods):
277277
return super().test_diff(data, periods)
278278

279+
@skip_nested
280+
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
281+
def test_equals(self, data, na_value, as_series, box):
282+
# Fails creating with _from_sequence
283+
super().test_equals(data, na_value, as_series, box)
284+
279285

280286
@skip_nested
281287
class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):

pandas/tests/extension/test_sparse.py

+5
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ def test_shift_0_periods(self, data):
316316
data._sparse_values[0] = data._sparse_values[1]
317317
assert result._sparse_values[0] != result._sparse_values[1]
318318

319+
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
320+
def test_equals(self, data, na_value, as_series, box):
321+
self._check_unsupported(data)
322+
super().test_equals(data, na_value, as_series, box)
323+
319324

320325
class TestCasting(BaseSparseTests, base.BaseCastingTests):
321326
def test_astype_object_series(self, all_data):

0 commit comments

Comments
 (0)