Skip to content

Commit 786963c

Browse files
committed
ENH: Updating eq and ne methods for extension arrays.
Adding __eq__ to ExtensionArray Abstract method doc string. Adding ne implementation to EA base class. Also removing other implementations. Updating EA equals method and adding tests. pandas-devGH-27081
1 parent 36c8b88 commit 786963c

File tree

9 files changed

+30
-58
lines changed

9 files changed

+30
-58
lines changed

pandas/core/arrays/base.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class ExtensionArray:
105105
* _from_sequence
106106
* _from_factorized
107107
* __getitem__
108+
* __eq__
108109
* __len__
109110
* dtype
110111
* nbytes
@@ -381,7 +382,7 @@ def __ne__(self, other: ABCExtensionArray) -> bool:
381382
bool
382383
"""
383384

384-
raise AbstractMethodError(self)
385+
return ~(self == other)
385386

386387
# ------------------------------------------------------------------------
387388
# Required attributes
@@ -705,7 +706,9 @@ def equals(self, other: ABCExtensionArray) -> bool:
705706
Whether the arrays are equivalent.
706707
707708
"""
708-
return ((self == other) | (self.isna() == other.isna())).all()
709+
return isinstance(other, self.__class__) and (
710+
((self == other) | (self.isna() == other.isna())).all()
711+
)
709712

710713
def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
711714
"""

pandas/core/arrays/boolean.py

-14
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,6 @@ def __getitem__(self, item):
327327

328328
return type(self)(self._data[item], self._mask[item])
329329

330-
def __eq__(self, other):
331-
if not isinstance(other, BooleanArray):
332-
return NotImplemented
333-
return (
334-
hasattr(other, "_data")
335-
and self._data == other._data
336-
and hasattr(other, "_mask")
337-
and self._mask == other._mask
338-
and hasattr(other, "_dtype") & self._dtype == other._dtype
339-
)
340-
341-
def __ne__(self, other):
342-
return not self.__eq__(other)
343-
344330
def _coerce_to_ndarray(self, dtype=None, na_value: "Scalar" = libmissing.NA):
345331
"""
346332
Coerce to an ndarray of object dtype or bool dtype (if force_bool=True).

pandas/core/arrays/categorical.py

-13
Original file line numberDiff line numberDiff line change
@@ -2071,19 +2071,6 @@ def __setitem__(self, key, value):
20712071
lindexer = self._maybe_coerce_indexer(lindexer)
20722072
self._codes[key] = lindexer
20732073

2074-
def __eq__(self, other):
2075-
if not isinstance(other, Categorical):
2076-
return NotImplemented
2077-
return (
2078-
hasattr(other, "_codes")
2079-
and self._codes == other._codes
2080-
and hasattr(other, "_dtype")
2081-
and self._dtype == other._dtype
2082-
)
2083-
2084-
def __ne__(self, other):
2085-
return not self.__eq__(other)
2086-
20872074
def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
20882075
"""
20892076
Compute the inverse of a categorical, returning

pandas/core/arrays/integer.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -377,18 +377,14 @@ def __getitem__(self, item):
377377
return type(self)(self._data[item], self._mask[item])
378378

379379
def __eq__(self, other):
380-
if not isinstance(other, IntegerArray):
381-
return NotImplemented
382380
return (
383-
hasattr(other, "_data")
381+
isinstance(other, IntegerArray)
382+
and hasattr(other, "_data")
384383
and self._data == other._data
385384
and hasattr(other, "_mask")
386385
and self._mask == other._mask
387386
)
388387

389-
def __ne__(self, other):
390-
return not self.__eq__(other)
391-
392388
def _coerce_to_ndarray(self, dtype=None, na_value=lib._no_default):
393389
"""
394390
coerce to an ndarary of object dtype

pandas/core/arrays/interval.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -548,20 +548,16 @@ def __setitem__(self, key, value):
548548
self._right = right
549549

550550
def __eq__(self, other):
551-
if not isinstance(other, IntervalArray):
552-
return NotImplementedError
553551
return (
554-
hasattr(other, "_left")
555-
and self._left == other._left
552+
isinstance(other, IntervalArray)
553+
and hasattr(other, "_left")
554+
and np.array_equal(self._left, other._left)
556555
and hasattr(other, "_right")
557-
and self._right == other._right
556+
and np.array_equal(self._right, other._right)
558557
and hasattr(other, "_closed")
559558
and self._closed == other._closed
560559
)
561560

562-
def __ne__(self, other):
563-
return not self.__eq__(other)
564-
565561
def fillna(self, value=None, method=None, limit=None):
566562
"""
567563
Fill NA/NaN values using the specified method.

pandas/tests/extension/base/methods.py

+15
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,18 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
358358
np.repeat(data, repeats, **kwargs)
359359
else:
360360
data.repeat(repeats, **kwargs)
361+
362+
def test_equals(self, data, na_value):
363+
cls = type(data)
364+
ser = pd.Series(cls._from_sequence(data, dtype=data.dtype))
365+
na_ser = pd.Series(cls._from_sequence([na_value], dtype=data.dtype))
366+
367+
assert data.equals(data)
368+
assert ser.equals(ser)
369+
assert na_ser.equals(na_ser)
370+
371+
assert not data.equals(na_value)
372+
assert not na_ser.equals(ser)
373+
assert not ser.equals(na_ser)
374+
assert not ser.equals(0)
375+
assert not na_ser.equals(0)

pandas/tests/extension/base/ops.py

-12
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,8 @@ class BaseComparisonOpsTests(BaseOpsUtil):
132132
def _compare_other(self, s, data, op_name, other):
133133
op = self.get_op_from_name(op_name)
134134
if op_name == "__eq__":
135-
assert getattr(data, op_name)(other) is NotImplemented
136135
assert not op(s, other).all()
137136
elif op_name == "__ne__":
138-
assert getattr(data, op_name)(other) is NotImplemented
139137
assert op(s, other).all()
140138

141139
else:
@@ -158,13 +156,3 @@ def test_compare_array(self, data, all_compare_operators):
158156
s = pd.Series(data)
159157
other = pd.Series([data[0]] * len(data))
160158
self._compare_other(s, data, op_name, other)
161-
162-
def test_direct_arith_with_series_returns_not_implemented(self, data):
163-
# EAs should return NotImplemented for ops with Series.
164-
# Pandas takes care of unboxing the series and calling the EA's op.
165-
other = pd.Series(data)
166-
if hasattr(data, "__eq__"):
167-
result = data.__eq__(other)
168-
assert result is NotImplemented
169-
else:
170-
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")

pandas/tests/extension/json/array.py

-3
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,6 @@ def __eq__(self, other):
117117
and self.data == other.data
118118
)
119119

120-
def __ne__(self, other):
121-
return not self.__eq__(other)
122-
123120
def __len__(self) -> int:
124121
return len(self.data)
125122

pandas/tests/extension/test_sparse.py

+4
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,10 @@ def test_searchsorted(self, data_for_sorting, as_series):
302302
with tm.assert_produces_warning(PerformanceWarning):
303303
super().test_searchsorted(data_for_sorting, as_series)
304304

305+
def test_equals(self, data, na_value):
306+
self._check_unsupported(data)
307+
super().test_equals(data, na_value)
308+
305309

306310
class TestCasting(BaseSparseTests, base.BaseCastingTests):
307311
pass

0 commit comments

Comments
 (0)