Skip to content

Commit 36c8b88

Browse files
committed
ENH: .equals for Extension Arrays
1 parent f937843 commit 36c8b88

File tree

6 files changed

+115
-0
lines changed

6 files changed

+115
-0
lines changed

pandas/core/arrays/base.py

+50
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class ExtensionArray:
7878
dropna
7979
factorize
8080
fillna
81+
equals
8182
isna
8283
ravel
8384
repeat
@@ -350,6 +351,38 @@ def __iter__(self):
350351
for i in range(len(self)):
351352
yield self[i]
352353

354+
def __eq__(self, other: ABCExtensionArray) -> bool:
    """
    Report whether this array is equivalent to ``other``.

    The base class provides no comparison of its own: calling this
    version always raises, so each concrete array type has to override it.

    Parameters
    ----------
    other : ExtensionArray
        The array to compare against this one.

    Returns
    -------
    bool

    Raises
    ------
    AbstractMethodError
        Always, when the method has not been overridden.
    """
    raise AbstractMethodError(self)
369+
370+
def __ne__(self, other: ABCExtensionArray) -> bool:
    """
    Report whether this array is not equivalent to ``other``.

    The base class provides no comparison of its own: calling this
    version always raises, so each concrete array type has to override it.

    Parameters
    ----------
    other : ExtensionArray
        The array to compare against this one.

    Returns
    -------
    bool

    Raises
    ------
    AbstractMethodError
        Always, when the method has not been overridden.
    """
    raise AbstractMethodError(self)
385+
353386
# ------------------------------------------------------------------------
354387
# Required attributes
355388
# ------------------------------------------------------------------------
@@ -657,6 +690,23 @@ def searchsorted(self, value, side="left", sorter=None):
657690
arr = self.astype(object)
658691
return arr.searchsorted(value, side=side, sorter=sorter)
659692

693+
def equals(self, other: object) -> bool:
    """
    Return if another array is equivalent to this array.

    Two arrays are equivalent when they have the same type, dtype and
    length, and every position either holds equal values or is missing
    in both arrays.

    Parameters
    ----------
    other : ExtensionArray
        Array to compare to this Array.

    Returns
    -------
    bool
        Whether the arrays are equivalent.
    """
    # Cheap structural checks first; they also keep the element-wise
    # comparison below from being attempted on mismatched operands.
    if type(self) is not type(other):
        return False
    if self.dtype != other.dtype or len(self) != len(other):
        return False

    equal_values = self == other
    if isinstance(equal_values, ExtensionArray):
        # A masked boolean result holds NA wherever either side is missing.
        # Treat those as "values differ" and let ``equal_na`` decide below.
        equal_values = equal_values.fillna(False)

    # A position matches on missingness only when BOTH sides are missing.
    # (Comparing the two isna() masks with ``==`` would also accept every
    # position where both sides are present, hiding real value differences.)
    equal_na = self.isna() & other.isna()
    return bool((equal_values | equal_na).all())
709+
660710
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
661711
"""
662712
Return an array and missing value suitable for factorization.

pandas/core/arrays/boolean.py

+14
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,20 @@ def __getitem__(self, item):
327327

328328
return type(self)(self._data[item], self._mask[item])
329329

330+
def __eq__(self, other):
    """
    Return whether ``other`` is a BooleanArray with identical contents.

    Two arrays compare equal when their dtypes, their underlying data and
    their missing-value masks all match.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a BooleanArray, so that
        Python can try the reflected comparison instead.
    """
    # Local import: keeps the method independent of module-level aliases.
    import numpy as np

    if not isinstance(other, BooleanArray):
        return NotImplemented
    # ``==`` between the backing arrays is element-wise and has no single
    # truth value, so each comparison is collapsed with ``np.array_equal``
    # (which also returns False for arrays of different shapes).
    return bool(
        self._dtype == other._dtype
        and np.array_equal(self._data, other._data)
        and np.array_equal(self._mask, other._mask)
    )
340+
341+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` is passed through unchanged when ``__eq__``
        does not support ``other``.
    """
    result = self.__eq__(other)
    # Negating NotImplemented would turn "unsupported" into a definite
    # answer; hand it back so Python can try the reflected operation.
    if result is NotImplemented:
        return NotImplemented
    return not result
343+
330344
def _coerce_to_ndarray(self, dtype=None, na_value: "Scalar" = libmissing.NA):
331345
"""
332346
Coerce to an ndarray of object dtype or bool dtype (if force_bool=True).

pandas/core/arrays/categorical.py

+13
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,19 @@ def __setitem__(self, key, value):
20712071
lindexer = self._maybe_coerce_indexer(lindexer)
20722072
self._codes[key] = lindexer
20732073

2074+
def __eq__(self, other):
    """
    Return whether ``other`` is a Categorical with identical contents.

    Two categoricals compare equal when their dtypes and their codes match.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a Categorical, so that
        Python can try the reflected comparison instead.
    """
    # Local import: keeps the method independent of module-level aliases.
    import numpy as np

    if not isinstance(other, Categorical):
        return NotImplemented
    # Comparing the code arrays with ``==`` is element-wise and has no
    # single truth value; ``np.array_equal`` reduces it to one bool and
    # returns False for arrays of different shapes.
    return bool(
        self._dtype == other._dtype
        and np.array_equal(self._codes, other._codes)
    )
2083+
2084+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` is passed through unchanged when ``__eq__``
        does not support ``other``.
    """
    result = self.__eq__(other)
    # Negating NotImplemented would turn "unsupported" into a definite
    # answer; hand it back so Python can try the reflected operation.
    if result is NotImplemented:
        return NotImplemented
    return not result
2086+
20742087
def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
20752088
"""
20762089
Compute the inverse of a categorical, returning

pandas/core/arrays/integer.py

+13
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,19 @@ def __getitem__(self, item):
376376

377377
return type(self)(self._data[item], self._mask[item])
378378

379+
def __eq__(self, other):
    """
    Return whether ``other`` is an IntegerArray with identical contents.

    Two arrays compare equal when their underlying data and their
    missing-value masks match.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not an IntegerArray, so that
        Python can try the reflected comparison instead.
    """
    # Local import: keeps the method independent of module-level aliases.
    import numpy as np

    if not isinstance(other, IntegerArray):
        return NotImplemented
    # ``==`` between the backing arrays is element-wise and has no single
    # truth value, so each comparison is collapsed with ``np.array_equal``
    # (which also returns False for arrays of different shapes).
    return bool(
        np.array_equal(self._data, other._data)
        and np.array_equal(self._mask, other._mask)
    )
388+
389+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` is passed through unchanged when ``__eq__``
        does not support ``other``.
    """
    result = self.__eq__(other)
    # Negating NotImplemented would turn "unsupported" into a definite
    # answer; hand it back so Python can try the reflected operation.
    if result is NotImplemented:
        return NotImplemented
    return not result
391+
379392
def _coerce_to_ndarray(self, dtype=None, na_value=lib._no_default):
380393
"""
381394
coerce to an ndarary of object dtype

pandas/core/arrays/interval.py

+15
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,21 @@ def __setitem__(self, key, value):
547547
right.values[key] = value_right
548548
self._right = right
549549

550+
def __eq__(self, other):
    """
    Return whether ``other`` is an IntervalArray with identical contents.

    Two arrays compare equal when their closed side matches and their
    left and right endpoints match.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not an IntervalArray, so that
        Python can try the reflected comparison instead.
    """
    # Local import: keeps the method independent of module-level aliases.
    import numpy as np

    if not isinstance(other, IntervalArray):
        # This must be the ``NotImplemented`` singleton. Returning the
        # ``NotImplementedError`` class would hand the caller a truthy
        # object and make ``arr == anything`` look like a match.
        return NotImplemented
    # Endpoint comparison with ``==`` is element-wise; reduce each side to
    # a single bool (False as well when the shapes differ).
    return bool(
        self._closed == other._closed
        and np.array_equal(self._left, other._left)
        and np.array_equal(self._right, other._right)
    )

def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` is passed through unchanged when ``__eq__``
        does not support ``other``.
    """
    result = self.__eq__(other)
    # Negating NotImplemented would turn "unsupported" into a definite
    # answer; hand it back so Python can try the reflected operation.
    if result is NotImplemented:
        return NotImplemented
    return not result
564+
550565
def fillna(self, value=None, method=None, limit=None):
551566
"""
552567
Fill NA/NaN values using the specified method.

pandas/tests/extension/json/array.py

+10
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ def __setitem__(self, key, value):
110110
assert isinstance(v, self.dtype.type)
111111
self.data[k] = v
112112

113+
def __eq__(self, other):
    """
    Compare this array's ``data`` with that of another JSONArray.

    Anything that is not a JSONArray, or that lacks a ``data`` attribute,
    compares as not equal.
    """
    if not isinstance(other, JSONArray):
        return False
    if not hasattr(other, "data"):
        return False
    return self.data == other.data
119+
120+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    ``NotImplemented`` is passed through unchanged should ``__eq__`` ever
    report that it does not support ``other``.
    """
    result = self.__eq__(other)
    # Negating NotImplemented would turn "unsupported" into a definite
    # answer; hand it back so Python can try the reflected operation.
    if result is NotImplemented:
        return NotImplemented
    return not result
122+
113123
def __len__(self) -> int:
    """Return the number of entries held in ``self.data``."""
    size = len(self.data)
    return size
115125

0 commit comments

Comments
 (0)