diff --git a/doc/source/reference/extensions.rst b/doc/source/reference/extensions.rst index 4c0763e091b75..fe4113d100abf 100644 --- a/doc/source/reference/extensions.rst +++ b/doc/source/reference/extensions.rst @@ -45,6 +45,7 @@ objects. api.extensions.ExtensionArray.copy api.extensions.ExtensionArray.view api.extensions.ExtensionArray.dropna + api.extensions.ExtensionArray.equals api.extensions.ExtensionArray.factorize api.extensions.ExtensionArray.fillna api.extensions.ExtensionArray.isna diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 9c424f70b1ee0..3ce0db2cf38d0 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -150,6 +150,8 @@ Other enhancements such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`) - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`) - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`) +- The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals` + method, similarly to :meth:`Series.equals` (:issue:`27081`). - .. --------------------------------------------------------------------------- diff --git a/pandas/_testing.py b/pandas/_testing.py index 8fbdcb89dafca..6424a8097fbf1 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -1490,7 +1490,9 @@ def box_expected(expected, box_cls, transpose=True): ------- subclass of box_cls """ - if box_cls is pd.Index: + if box_cls is pd.array: + expected = pd.array(expected) + elif box_cls is pd.Index: expected = pd.Index(expected) elif box_cls is pd.Series: expected = pd.Series(expected) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bd903d9b1fae3..0c5634a932e12 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -58,6 +58,7 @@ class ExtensionArray: dropna factorize fillna + equals isna ravel repeat @@ -84,6 +85,7 @@ class ExtensionArray: * _from_factorized * __getitem__ * __len__ + * __eq__ * dtype * nbytes * isna @@ -333,6 +335,24 @@ def __iter__(self): for i in range(len(self)): yield self[i] + def __eq__(self, other: Any) -> ArrayLike: + """ + Return for `self == other` (element-wise equality). + """ + # Implementer note: this should return a boolean numpy ndarray or + # a boolean ExtensionArray. + # When `other` is one of Series, Index, or DataFrame, this method should + # return NotImplemented (to ensure that those objects are responsible for + # first unpacking the arrays, and then dispatch the operation to the + # underlying arrays) + raise AbstractMethodError(self) + + def __ne__(self, other: Any) -> ArrayLike: + """ + Return for `self != other` (element-wise in-equality). + """ + return ~(self == other) + def to_numpy( self, dtype=None, copy: bool = False, na_value=lib.no_default ) -> np.ndarray: @@ -682,6 +702,38 @@ def searchsorted(self, value, side="left", sorter=None): arr = self.astype(object) return arr.searchsorted(value, side=side, sorter=sorter) + def equals(self, other: "ExtensionArray") -> bool: + """ + Return if another array is equivalent to this array. + + Equivalent means that both arrays have the same shape and dtype, and + all values compare equal. Missing values in the same location are + considered equal (in contrast with normal equality). + + Parameters + ---------- + other : ExtensionArray + Array to compare to this Array. + + Returns + ------- + boolean + Whether the arrays are equivalent. + """ + if not type(self) == type(other): + return False + elif not self.dtype == other.dtype: + return False + elif not len(self) == len(other): + return False + else: + equal_values = self == other + if isinstance(equal_values, ExtensionArray): + # boolean array with NA -> fill with False + equal_values = equal_values.fillna(False) + equal_na = self.isna() & other.isna() + return (equal_values | equal_na).all().item() + def _values_for_factorize(self) -> Tuple[np.ndarray, Any]: """ Return an array and missing value suitable for factorization. @@ -1134,7 +1186,7 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin): """ @classmethod - def _create_method(cls, op, coerce_to_dtype=True): + def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None): """ A class method that returns a method that will correspond to an operator for an ExtensionArray subclass, by dispatching to the @@ -1202,7 +1254,7 @@ def _maybe_convert(arr): # exception raised in _from_sequence; ensure we have ndarray res = np.asarray(arr) else: - res = np.asarray(arr) + res = np.asarray(arr, dtype=result_dtype) return res if op.__name__ in {"divmod", "rdivmod"}: @@ -1220,4 +1272,4 @@ def _create_arithmetic_method(cls, op): @classmethod def _create_comparison_method(cls, op): - return cls._create_method(op, coerce_to_dtype=False) + return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 66faca29670cb..8cac909b70802 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -606,9 +606,6 @@ def __eq__(self, other): return result - def __ne__(self, other): - return ~self.__eq__(other) - def fillna(self, value=None, method=None, limit=None): """ Fill NA/NaN values using the specified method. diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index e4dcffae45f67..d22adf2aaf179 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1864,6 +1864,9 @@ def where( return [self.make_block_same_class(result, placement=self.mgr_locs)] + def equals(self, other) -> bool: + return self.values.equals(other.values) + def _unstack(self, unstacker, fill_value, new_placement): # ExtensionArray-safe unstack. # We override ObjectBlock._unstack, which unstacks directly on the diff --git a/pandas/tests/arrays/integer/test_comparison.py b/pandas/tests/arrays/integer/test_comparison.py index d76ed2c21ca0e..1767250af09b0 100644 --- a/pandas/tests/arrays/integer/test_comparison.py +++ b/pandas/tests/arrays/integer/test_comparison.py @@ -104,3 +104,13 @@ def test_compare_to_int(self, any_nullable_int_dtype, all_compare_operators): expected[s2.isna()] = pd.NA self.assert_series_equal(result, expected) + + +def test_equals(): + # GH-30652 + # equals is generally tested in /tests/extension/base/methods, but this + # specifically tests that two arrays of the same class but different dtype + # do not evaluate equal + a1 = pd.array([1, 2, None], dtype="Int64") + a2 = pd.array([1, 2, None], dtype="Int32") + assert a1.equals(a2) is False diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index ca92c2e1e318d..4a6d827b36b02 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -421,3 +421,32 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy): np.repeat(data, repeats, **kwargs) else: data.repeat(repeats, **kwargs) + + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) + def test_equals(self, data, na_value, as_series, box): + data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype) + data_na = type(data)._from_sequence([na_value] * len(data), dtype=data.dtype) + + data = tm.box_expected(data, box, transpose=False) + data2 = tm.box_expected(data2, box, transpose=False) + data_na = tm.box_expected(data_na, box, transpose=False) + + # we are asserting with `is True/False` explicitly, to test that the + # result is an actual Python bool, and not something "truthy" + + assert data.equals(data) is True + assert data.equals(data.copy()) is True + + # unequal other data + assert data.equals(data2) is False + assert data.equals(data_na) is False + + # different length + assert data[:2].equals(data[:3]) is False + + # emtpy are equal + assert data[:0].equals(data[:0]) is True + + # other types + assert data.equals(None) is False + assert data[[0]].equals(data[0]) is False diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index d3b6472044ea5..188893c8b067c 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -139,10 +139,8 @@ class BaseComparisonOpsTests(BaseOpsUtil): def _compare_other(self, s, data, op_name, other): op = self.get_op_from_name(op_name) if op_name == "__eq__": - assert getattr(data, op_name)(other) is NotImplemented assert not op(s, other).all() elif op_name == "__ne__": - assert getattr(data, op_name)(other) is NotImplemented assert op(s, other).all() else: @@ -176,6 +174,12 @@ def test_direct_arith_with_series_returns_not_implemented(self, data): else: raise pytest.skip(f"{type(data).__name__} does not implement __eq__") + if hasattr(data, "__ne__"): + result = data.__ne__(other) + assert result is NotImplemented + else: + raise pytest.skip(f"{type(data).__name__} does not implement __ne__") + class BaseUnaryOpsTests(BaseOpsUtil): def test_invert(self, data): diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 1f026e405dc17..94f971938b690 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -105,6 +105,12 @@ def __setitem__(self, key, value): def __len__(self) -> int: return len(self.data) + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + def __array__(self, dtype=None): if dtype is None: dtype = object diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index d79769208ab56..74ca341e27bf8 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -262,6 +262,10 @@ def test_where_series(self, data, na_value): def test_searchsorted(self, data_for_sorting): super().test_searchsorted(data_for_sorting) + @pytest.mark.skip(reason="Can't compare dicts.") + def test_equals(self, data, na_value, as_series): + pass + class TestCasting(BaseJSON, base.BaseCastingTests): @pytest.mark.skip(reason="failing on np.array(self, dtype=str)") diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index e48065b47f17c..1e21249988df6 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -276,6 +276,12 @@ def test_repeat(self, data, repeats, as_series, use_numpy): def test_diff(self, data, periods): return super().test_diff(data, periods) + @skip_nested + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) + def test_equals(self, data, na_value, as_series, box): + # Fails creating with _from_sequence + super().test_equals(data, na_value, as_series, box) + @skip_nested class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests): diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index 19ac25eb0ccf7..e59b3f0600867 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -316,6 +316,11 @@ def test_shift_0_periods(self, data): data._sparse_values[0] = data._sparse_values[1] assert result._sparse_values[0] != result._sparse_values[1] + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) + def test_equals(self, data, na_value, as_series, box): + self._check_unsupported(data) + super().test_equals(data, na_value, as_series, box) + class TestCasting(BaseSparseTests, base.BaseCastingTests): def test_astype_object_series(self, all_data):