Skip to content

Commit 1919814

Browse files
committed
[WIP] ENH: .equals for Extension Arrays
1 parent b804372 commit 1919814

File tree

6 files changed

+115
-0
lines changed

6 files changed

+115
-0
lines changed

pandas/core/arrays/base.py

+50
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class ExtensionArray:
7878
dropna
7979
factorize
8080
fillna
81+
equals
8182
isna
8283
ravel
8384
repeat
@@ -350,6 +351,38 @@ def __iter__(self):
350351
for i in range(len(self)):
351352
yield self[i]
352353

354+
def __eq__(self, other: ABCExtensionArray) -> bool:
    """
    Report whether ``other`` is equivalent to this array.

    The base class supplies no implementation: calling this method always
    raises, so concrete array classes have to provide their own version.

    Parameters
    ----------
    other : ExtensionArray
        The array to compare to this array.

    Returns
    -------
    bool

    Raises
    ------
    AbstractMethodError
        Always, unless the method has been overridden.
    """
    raise AbstractMethodError(self)
369+
370+
def __ne__(self, other: ABCExtensionArray) -> bool:
    """
    Report whether ``other`` is not equivalent to this array.

    The base class supplies no implementation: calling this method always
    raises, so concrete array classes have to provide their own version.

    Parameters
    ----------
    other : ExtensionArray
        The array to compare to this array.

    Returns
    -------
    bool

    Raises
    ------
    AbstractMethodError
        Always, unless the method has been overridden.
    """
    raise AbstractMethodError(self)
385+
353386
# ------------------------------------------------------------------------
354387
# Required attributes
355388
# ------------------------------------------------------------------------
@@ -657,6 +690,23 @@ def searchsorted(self, value, side="left", sorter=None):
657690
arr = self.astype(object)
658691
return arr.searchsorted(value, side=side, sorter=sorter)
659692

693+
def equals(self, other: object) -> bool:
    """
    Return if another array is equivalent to this array.

    Two arrays are equivalent when they have the same type, dtype and
    length, and every position either holds equal values or is missing
    in both arrays.

    Parameters
    ----------
    other : ExtensionArray
        Array to compare to this Array.

    Returns
    -------
    bool
        Whether the arrays are equivalent. ``False`` (rather than an
        exception) when ``other`` is of a different type, dtype or length.
    """
    if type(self) is not type(other):
        return False
    if self.dtype != other.dtype or len(self) != len(other):
        return False

    equal_values = self == other
    if isinstance(equal_values, (bool, np.bool_)):
        # ``__eq__`` already collapsed the comparison to a single answer.
        return bool(equal_values)
    if not isinstance(equal_values, np.ndarray):
        # A nullable result holds NA wherever either operand is missing.
        # Count "unknown" as "not equal" so ``all`` cannot silently skip it.
        equal_values = equal_values.fillna(False)

    # A position also matches when it is missing on *both* sides.  Comparing
    # the two masks with ``==`` would be True wherever neither side is
    # missing, which makes every pair of NA-free arrays look equivalent.
    equal_na = self.isna() & other.isna()
    return bool((equal_values | equal_na).all())
709+
660710
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
661711
"""
662712
Return an array and missing value suitable for factorization.

pandas/core/arrays/boolean.py

+14
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,20 @@ def __getitem__(self, item):
314314
return self._data[item]
315315
return type(self)(self._data[item], self._mask[item])
316316

317+
def __eq__(self, other):
    """
    Whether ``other`` is a BooleanArray equivalent to this one.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a BooleanArray. Otherwise
        ``True`` only when both arrays share dtype, length and missing-value
        mask, and agree on every non-missing value.
    """
    if not isinstance(other, BooleanArray):
        return NotImplemented
    if self._dtype != other._dtype:
        return False
    if len(self._data) != len(other._data):
        return False
    # Reduce each element-wise comparison with ``all``: chaining the raw
    # arrays with ``and`` raises "truth value of an array is ambiguous".
    if not (self._mask == other._mask).all():
        return False
    # Values stored under the mask are placeholders, so only compare the
    # positions that hold real data (assumes ``_data``/``_mask`` are numpy
    # arrays).
    valid = ~self._mask
    return bool((self._data[valid] == other._data[valid]).all())
327+
328+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    ``NotImplemented`` is passed through unchanged so Python can try the
    reflected comparison; negating it would wrongly report "equal".
    """
    result = self.__eq__(other)
    if result is NotImplemented:
        return NotImplemented
    return not result
330+
317331
def _coerce_to_ndarray(self, dtype=None, na_value: "Scalar" = libmissing.NA):
318332
"""
319333
Coerce to an ndarray of object dtype or bool dtype (if force_bool=True).

pandas/core/arrays/categorical.py

+13
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,19 @@ def __setitem__(self, key, value):
20672067
lindexer = self._maybe_coerce_indexer(lindexer)
20682068
self._codes[key] = lindexer
20692069

2070+
def __eq__(self, other):
    """
    Whether ``other`` is a Categorical equivalent to this one.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not a Categorical. Otherwise
        ``True`` only when both arrays have equal dtypes, equal length and
        identical codes.
    """
    if not isinstance(other, Categorical):
        return NotImplemented
    # Codes are only comparable once the dtypes agree, so check that first.
    if self._dtype != other._dtype:
        return False
    if len(self._codes) != len(other._codes):
        return False
    # NOTE(review): if two dtypes can compare equal while listing their
    # categories in a different order, identical values would carry
    # different codes here -- confirm whether recoding is needed.
    # Reduce with ``all``: using the raw element-wise result in a boolean
    # context raises "truth value of an array is ambiguous".
    return bool((self._codes == other._codes).all())
2079+
2080+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    ``NotImplemented`` is passed through unchanged so Python can try the
    reflected comparison; negating it would wrongly report "equal".
    """
    result = self.__eq__(other)
    if result is NotImplemented:
        return NotImplemented
    return not result
2082+
20702083
def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
20712084
"""
20722085
Compute the inverse of a categorical, returning

pandas/core/arrays/integer.py

+13
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,19 @@ def __getitem__(self, item):
370370
return self._data[item]
371371
return type(self)(self._data[item], self._mask[item])
372372

373+
def __eq__(self, other):
    """
    Whether ``other`` is an IntegerArray equivalent to this one.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not an IntegerArray. Otherwise
        ``True`` only when both arrays have the same length and
        missing-value mask, and agree on every non-missing value.
    """
    if not isinstance(other, IntegerArray):
        return NotImplemented
    if len(self._data) != len(other._data):
        return False
    # Reduce each element-wise comparison with ``all``: chaining the raw
    # arrays with ``and`` raises "truth value of an array is ambiguous".
    if not (self._mask == other._mask).all():
        return False
    # Values stored under the mask are placeholders, so only compare the
    # positions that hold real data (assumes ``_data``/``_mask`` are numpy
    # arrays).
    valid = ~self._mask
    return bool((self._data[valid] == other._data[valid]).all())
382+
383+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    ``NotImplemented`` is passed through unchanged so Python can try the
    reflected comparison; negating it would wrongly report "equal".
    """
    result = self.__eq__(other)
    if result is NotImplemented:
        return NotImplemented
    return not result
385+
373386
def _coerce_to_ndarray(self, dtype=None, na_value=lib._no_default):
374387
"""
375388
coerce to an ndarary of object dtype

pandas/core/arrays/interval.py

+15
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,21 @@ def __setitem__(self, key, value):
547547
right.values[key] = value_right
548548
self._right = right
549549

550+
def __eq__(self, other):
    """
    Whether ``other`` is an IntervalArray equivalent to this one.

    Parameters
    ----------
    other : object
        The object to compare to this array.

    Returns
    -------
    bool or NotImplemented
        ``NotImplemented`` when ``other`` is not an IntervalArray. Otherwise
        ``True`` only when both arrays are closed on the same side(s) and
        have equal left and right bounds.
    """
    if not isinstance(other, IntervalArray):
        # Must be the ``NotImplemented`` singleton: returning the
        # ``NotImplementedError`` class is truthy and reads as "equal".
        return NotImplemented
    if self.closed != other.closed:
        return False
    # ``equals`` gives one bool per bound; chaining the element-wise ``==``
    # results with ``and`` raises "truth value is ambiguous".
    return bool(self.left.equals(other.left) and self.right.equals(other.right))
561+
562+
def __ne__(self, other):
    """
    Return the negation of ``self.__eq__(other)``.

    ``NotImplemented`` is passed through unchanged so Python can try the
    reflected comparison; negating it would wrongly report "equal".
    """
    result = self.__eq__(other)
    if result is NotImplemented:
        return NotImplemented
    return not result
564+
550565
def fillna(self, value=None, method=None, limit=None):
551566
"""
552567
Fill NA/NaN values using the specified method.

pandas/tests/extension/json/array.py

+10
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ def __setitem__(self, key, value):
107107
assert isinstance(v, self.dtype.type)
108108
self.data[k] = v
109109

110+
def __eq__(self, other):
    """
    Compare this array's ``data`` with that of another JSONArray.

    Returns ``False`` when ``other`` is not a JSONArray or has no ``data``
    attribute; otherwise returns the result of ``self.data == other.data``.
    """
    if not isinstance(other, JSONArray):
        return False
    if not hasattr(other, "data"):
        return False
    return self.data == other.data
116+
117+
def __ne__(self, other):
    """Return the logical negation of ``self.__eq__(other)``."""
    is_equal = self.__eq__(other)
    return not is_equal
119+
110120
def __len__(self) -> int:
    """Return the number of items held in ``self.data``."""
    size = len(self.data)
    return size
112122

0 commit comments

Comments
 (0)