diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index 75188ad5b00eb..8551ce9f14e6c 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -355,7 +355,9 @@ def _isna_compat(arr, fill_value=np.nan) -> bool: return True -def array_equivalent(left, right, strict_nan: bool = False) -> bool: +def array_equivalent( + left, right, strict_nan: bool = False, dtype_equal: bool = False +) -> bool: """ True if two arrays, left and right, have equal non-NaN elements, and NaNs in corresponding locations. False otherwise. It is assumed that left and @@ -368,6 +370,12 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool: left, right : ndarrays strict_nan : bool, default False If True, consider NaN and None to be different. + dtype_equal : bool, default False + Whether `left` and `right` are known to have the same dtype + according to `is_dtype_equal`. Some methods like `BlockManager.equals`. + require that the dtypes match. Setting this to ``True`` can improve + performance, but will give different results for arrays that are + equal but different dtypes. Returns ------- @@ -391,43 +399,28 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool: if left.shape != right.shape: return False + if dtype_equal: + # fastpath when we require that the dtypes match (Block.equals) + if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype): + return _array_equivalent_float(left, right) + elif is_datetimelike_v_numeric(left.dtype, right.dtype): + return False + elif needs_i8_conversion(left.dtype): + return _array_equivalent_datetimelike(left, right) + elif is_string_dtype(left.dtype): + # TODO: fastpath for pandas' StringDtype + return _array_equivalent_object(left, right, strict_nan) + else: + return np.array_equal(left, right) + + # Slow path when we allow comparing different dtypes. # Object arrays can contain None, NaN and NaT. # string dtypes must be come to this path for NumPy 1.7.1 compat if is_string_dtype(left.dtype) or is_string_dtype(right.dtype): - - if not strict_nan: - # isna considers NaN and None to be equivalent. - return lib.array_equivalent_object( - ensure_object(left.ravel()), ensure_object(right.ravel()) - ) - - for left_value, right_value in zip(left, right): - if left_value is NaT and right_value is not NaT: - return False - - elif left_value is libmissing.NA and right_value is not libmissing.NA: - return False - - elif isinstance(left_value, float) and np.isnan(left_value): - if not isinstance(right_value, float) or not np.isnan(right_value): - return False - else: - try: - if np.any(np.asarray(left_value != right_value)): - return False - except TypeError as err: - if "Cannot compare tz-naive" in str(err): - # tzawareness compat failure, see GH#28507 - return False - elif "boolean value of NA is ambiguous" in str(err): - return False - raise - return True + return _array_equivalent_object(left, right, strict_nan) # NaNs can occur in float and complex arrays. if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype): - - # empty if not (np.prod(left.shape) and np.prod(right.shape)): return True return ((left == right) | (isna(left) & isna(right))).all() @@ -452,6 +445,45 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool: return np.array_equal(left, right) +def _array_equivalent_float(left, right): + return ((left == right) | (np.isnan(left) & np.isnan(right))).all() + + +def _array_equivalent_datetimelike(left, right): + return np.array_equal(left.view("i8"), right.view("i8")) + + +def _array_equivalent_object(left, right, strict_nan): + if not strict_nan: + # isna considers NaN and None to be equivalent. + return lib.array_equivalent_object( + ensure_object(left.ravel()), ensure_object(right.ravel()) + ) + + for left_value, right_value in zip(left, right): + if left_value is NaT and right_value is not NaT: + return False + + elif left_value is libmissing.NA and right_value is not libmissing.NA: + return False + + elif isinstance(left_value, float) and np.isnan(left_value): + if not isinstance(right_value, float) or not np.isnan(right_value): + return False + else: + try: + if np.any(np.asarray(left_value != right_value)): + return False + except TypeError as err: + if "Cannot compare tz-naive" in str(err): + # tzawareness compat failure, see GH#28507 + return False + elif "boolean value of NA is ambiguous" in str(err): + return False + raise + return True + + def _infer_fill_value(val): """ infer the fill value for the nan/NaT from the provided diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index d5947726af7fd..895385b170c91 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -1436,7 +1436,7 @@ def equals(self, other: "BlockManager") -> bool: return array_equivalent(left, right) for i in range(len(self.items)): - # Check column-wise, return False if any column doesnt match + # Check column-wise, return False if any column doesn't match left = self.iget_values(i) right = other.iget_values(i) if not is_dtype_equal(left.dtype, right.dtype): @@ -1445,7 +1445,7 @@ def equals(self, other: "BlockManager") -> bool: if not left.equals(right): return False else: - if not array_equivalent(left, right): + if not array_equivalent(left, right, dtype_equal=True): return False return True diff --git a/pandas/tests/dtypes/test_missing.py b/pandas/tests/dtypes/test_missing.py index f9a854c5778a2..04dde08de082d 100644 --- a/pandas/tests/dtypes/test_missing.py +++ b/pandas/tests/dtypes/test_missing.py @@ -300,50 +300,80 @@ def test_period(self): tm.assert_series_equal(notna(s), ~exp) -def test_array_equivalent(): - assert array_equivalent(np.array([np.nan, np.nan]), np.array([np.nan, np.nan])) +@pytest.mark.parametrize("dtype_equal", [True, False]) +def test_array_equivalent(dtype_equal): assert array_equivalent( - np.array([np.nan, 1, np.nan]), np.array([np.nan, 1, np.nan]) + np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), dtype_equal=dtype_equal + ) + assert array_equivalent( + np.array([np.nan, 1, np.nan]), + np.array([np.nan, 1, np.nan]), + dtype_equal=dtype_equal, ) assert array_equivalent( np.array([np.nan, None], dtype="object"), np.array([np.nan, None], dtype="object"), + dtype_equal=dtype_equal, ) # Check the handling of nested arrays in array_equivalent_object assert array_equivalent( np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"), np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"), + dtype_equal=dtype_equal, ) assert array_equivalent( np.array([np.nan, 1 + 1j], dtype="complex"), np.array([np.nan, 1 + 1j], dtype="complex"), + dtype_equal=dtype_equal, ) assert not array_equivalent( np.array([np.nan, 1 + 1j], dtype="complex"), np.array([np.nan, 1 + 2j], dtype="complex"), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + np.array([np.nan, 1, np.nan]), + np.array([np.nan, 2, np.nan]), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + np.array(["a", "b", "c", "d"]), np.array(["e", "e"]), dtype_equal=dtype_equal + ) + assert array_equivalent( + Float64Index([0, np.nan]), Float64Index([0, np.nan]), dtype_equal=dtype_equal ) assert not array_equivalent( - np.array([np.nan, 1, np.nan]), np.array([np.nan, 2, np.nan]) + Float64Index([0, np.nan]), Float64Index([1, np.nan]), dtype_equal=dtype_equal + ) + assert array_equivalent( + DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]), dtype_equal=dtype_equal + ) + assert not array_equivalent( + DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]), dtype_equal=dtype_equal + ) + assert array_equivalent( + TimedeltaIndex([0, np.nan]), + TimedeltaIndex([0, np.nan]), + dtype_equal=dtype_equal, ) - assert not array_equivalent(np.array(["a", "b", "c", "d"]), np.array(["e", "e"])) - assert array_equivalent(Float64Index([0, np.nan]), Float64Index([0, np.nan])) - assert not array_equivalent(Float64Index([0, np.nan]), Float64Index([1, np.nan])) - assert array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan])) - assert not array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan])) - assert array_equivalent(TimedeltaIndex([0, np.nan]), TimedeltaIndex([0, np.nan])) assert not array_equivalent( - TimedeltaIndex([0, np.nan]), TimedeltaIndex([1, np.nan]) + TimedeltaIndex([0, np.nan]), + TimedeltaIndex([1, np.nan]), + dtype_equal=dtype_equal, ) assert array_equivalent( DatetimeIndex([0, np.nan], tz="US/Eastern"), DatetimeIndex([0, np.nan], tz="US/Eastern"), + dtype_equal=dtype_equal, ) assert not array_equivalent( DatetimeIndex([0, np.nan], tz="US/Eastern"), DatetimeIndex([1, np.nan], tz="US/Eastern"), + dtype_equal=dtype_equal, ) + # The rest are not dtype_equal assert not array_equivalent( - DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern") + DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern"), ) assert not array_equivalent( DatetimeIndex([0, np.nan], tz="CET"), @@ -353,6 +383,11 @@ def test_array_equivalent(): assert not array_equivalent(DatetimeIndex([0, np.nan]), TimedeltaIndex([0, np.nan])) +def test_array_equivalent_different_dtype_but_equal(): + # Unclear if this is exposed anywhere in the public-facing API + assert array_equivalent(np.array([1, 2]), np.array([1.0, 2.0])) + + @pytest.mark.parametrize( "lvalue, rvalue", [