diff --git a/pandas/_libs/sparse.pyx b/pandas/_libs/sparse.pyx index 321d7c374d8ec..0c3d8915b749b 100644 --- a/pandas/_libs/sparse.pyx +++ b/pandas/_libs/sparse.pyx @@ -103,7 +103,7 @@ cdef class IntIndex(SparseIndex): if not monotonic: raise ValueError("Indices must be strictly increasing") - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: if not isinstance(other, IntIndex): return False @@ -399,7 +399,7 @@ cdef class BlockIndex(SparseIndex): if blengths[i] == 0: raise ValueError(f'Zero-length block {i}') - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: if not isinstance(other, BlockIndex): return False diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 921927325a144..d85647edc3b81 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -7,7 +7,7 @@ without warning. """ import operator -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -20,7 +20,12 @@ from pandas.util._validators import validate_fillna_kwargs from pandas.core.dtypes.cast import maybe_cast_to_extension_array -from pandas.core.dtypes.common import is_array_like, is_list_like, pandas_dtype +from pandas.core.dtypes.common import ( + is_array_like, + is_dtype_equal, + is_list_like, + pandas_dtype, +) from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries from pandas.core.dtypes.missing import isna @@ -742,7 +747,7 @@ def searchsorted(self, value, side="left", sorter=None): arr = self.astype(object) return arr.searchsorted(value, side=side, sorter=sorter) - def equals(self, other: "ExtensionArray") -> bool: + def equals(self, other: object) -> bool: """ Return if another array is equivalent to this array. @@ -762,7 +767,8 @@ def equals(self, other: "ExtensionArray") -> bool: """ if not type(self) == type(other): return False - elif not self.dtype == other.dtype: + other = cast(ExtensionArray, other) + if not is_dtype_equal(self.dtype, other.dtype): return False elif not len(self) == len(other): return False diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 6e5c7bc699962..a28b341669918 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2242,7 +2242,7 @@ def _from_factorized(cls, uniques, original): original.categories.take(uniques), dtype=original.dtype ) - def equals(self, other): + def equals(self, other: object) -> bool: """ Returns True if categorical arrays are equal. @@ -2254,7 +2254,9 @@ def equals(self, other): ------- bool """ - if self.is_dtype_equal(other): + if not isinstance(other, Categorical): + return False + elif self.is_dtype_equal(other): if self.categories.equals(other.categories): # fastpath to avoid re-coding other_codes = other._codes diff --git a/pandas/core/generic.py b/pandas/core/generic.py index fcb7e2a949205..651c079ecc08e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -22,6 +22,7 @@ Tuple, Type, Union, + cast, ) import warnings import weakref @@ -1195,7 +1196,7 @@ def _indexed_same(self, other) -> bool: self._get_axis(a).equals(other._get_axis(a)) for a in self._AXIS_ORDERS ) - def equals(self, other): + def equals(self, other: object) -> bool: """ Test whether two objects contain the same elements. @@ -1275,6 +1276,7 @@ def equals(self, other): """ if not (isinstance(other, type(self)) or isinstance(self, type(other))): return False + other = cast(NDFrame, other) return self._mgr.equals(other._mgr) # ------------------------------------------------------------------------- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index ecd3670e724a1..623ce68201492 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4176,7 +4176,7 @@ def putmask(self, mask, value): # coerces to object return self.astype(object).putmask(mask, value) - def equals(self, other: Any) -> bool: + def equals(self, other: object) -> bool: """ Determine if two Index object are equal. diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index fb283cbe02954..4990e6a8e20e9 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -290,7 +290,7 @@ def _is_dtype_compat(self, other) -> bool: return other - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: """ Determine if two CategoricalIndex objects contain the same elements. diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 8ccdab21339df..6d9d75a69e91d 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -24,7 +24,7 @@ is_scalar, ) from pandas.core.dtypes.concat import concat_compat -from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries +from pandas.core.dtypes.generic import ABCIndex, ABCSeries from pandas.core import algorithms from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray @@ -130,14 +130,14 @@ def __array_wrap__(self, result, context=None): # ------------------------------------------------------------------------ - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: """ Determines if two Index objects contain the same elements. """ if self.is_(other): return True - if not isinstance(other, ABCIndexClass): + if not isinstance(other, Index): return False elif not isinstance(other, type(self)): try: diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 9548ebbd9c3b2..e8d0a44324cc5 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1005,19 +1005,20 @@ def _format_space(self) -> str: def argsort(self, *args, **kwargs) -> np.ndarray: return np.lexsort((self.right, self.left)) - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: """ Determines if two IntervalIndex objects contain the same elements. """ if self.is_(other): return True - # if we can coerce to an II - # then we can compare + # if we can coerce to an IntervalIndex then we can compare if not isinstance(other, IntervalIndex): if not is_interval_dtype(other): return False other = Index(other) + if not isinstance(other, IntervalIndex): + return False return ( self.left.equals(other.left) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 13927dede5542..9ca3075d2962c 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3227,7 +3227,7 @@ def truncate(self, before=None, after=None): verify_integrity=False, ) - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: """ Determines if two MultiIndex objects have the same labeling information (the levels themselves do not necessarily have to be the same) @@ -3270,11 +3270,10 @@ def equals(self, other) -> bool: np.asarray(other.levels[i]._values), other_codes, allow_fill=False ) - # since we use NaT both datetime64 and timedelta64 - # we can have a situation where a level is typed say - # timedelta64 in self (IOW it has other values than NaT) - # but types datetime64 in other (where its all NaT) - # but these are equivalent + # since we use NaT both datetime64 and timedelta64 we can have a + # situation where a level is typed say timedelta64 in self (IOW it + # has other values than NaT) but types datetime64 in other (where + # its all NaT) but these are equivalent if len(self_values) == 0 and len(other_values) == 0: continue diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 3577a7aacc008..6080c32052266 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -433,7 +433,7 @@ def argsort(self, *args, **kwargs) -> np.ndarray: else: return np.arange(len(self) - 1, -1, -1) - def equals(self, other) -> bool: + def equals(self, other: object) -> bool: """ Determines if two Index objects contain the same elements. """ diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index aa74d173d69b3..371b721f08b27 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -1437,7 +1437,10 @@ def take(self, indexer, axis: int = 1, verify: bool = True, convert: bool = True new_axis=new_labels, indexer=indexer, axis=axis, allow_dups=True ) - def equals(self, other: "BlockManager") -> bool: + def equals(self, other: object) -> bool: + if not isinstance(other, BlockManager): + return False + self_axes, other_axes = self.axes, other.axes if len(self_axes) != len(other_axes): return False