Skip to content

Optimize array_equivalent for NDFrame.equals #35328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 64 additions & 32 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def _isna_compat(arr, fill_value=np.nan) -> bool:
return True


def array_equivalent(left, right, strict_nan: bool = False) -> bool:
def array_equivalent(
left, right, strict_nan: bool = False, dtype_equal: bool = False
) -> bool:
"""
True if two arrays, left and right, have equal non-NaN elements, and NaNs
in corresponding locations. False otherwise. It is assumed that left and
Expand All @@ -368,6 +370,12 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
left, right : ndarrays
strict_nan : bool, default False
If True, consider NaN and None to be different.
dtype_equal : bool, default False
Whether `left` and `right` are known to have the same dtype
according to `is_dtype_equal`. Some methods like `BlockManager.equals`
require that the dtypes match. Setting this to ``True`` can improve
performance, but will give different results for arrays that are
equal but different dtypes.

Returns
-------
Expand All @@ -391,43 +399,28 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
if left.shape != right.shape:
return False

if dtype_equal:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am not sure i understand why you actually need this
the dtype check is cheap compared to the actual check no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. They add up, especially when doing things columnwise for small, wide dataframes
  2. We get to skip a few other checks, like swapping np.isnan for pd.isna for float, and we can skip the empty check at https://github.com/pandas-dev/pandas/pull/35328/files#diff-ff8364cee9a3e1ef3a3825cb2cdd26d8L431.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok that's fine

# fastpath when we require that the dtypes match (Block.equals)
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
return array_equivalent_float(left, right)
elif is_datetimelike_v_numeric(left.dtype, right.dtype):
return False
elif needs_i8_conversion(left.dtype):
return array_equivalent_datetimelike(left, right)
elif is_string_dtype(left.dtype):
# TODO: fastpath for pandas' StringDtype
return array_equivalent_object(left, right, strict_nan)
else:
return np.array_equal(left, right)

# Slow path when we allow comparing different dtypes.
# Object arrays can contain None, NaN and NaT.
# string dtypes must come to this path for NumPy 1.7.1 compat
if is_string_dtype(left.dtype) or is_string_dtype(right.dtype):

if not strict_nan:
# isna considers NaN and None to be equivalent.
return lib.array_equivalent_object(
ensure_object(left.ravel()), ensure_object(right.ravel())
)

for left_value, right_value in zip(left, right):
if left_value is NaT and right_value is not NaT:
return False

elif left_value is libmissing.NA and right_value is not libmissing.NA:
return False

elif isinstance(left_value, float) and np.isnan(left_value):
if not isinstance(right_value, float) or not np.isnan(right_value):
return False
else:
try:
if np.any(np.asarray(left_value != right_value)):
return False
except TypeError as err:
if "Cannot compare tz-naive" in str(err):
# tzawareness compat failure, see GH#28507
return False
elif "boolean value of NA is ambiguous" in str(err):
return False
raise
return True
return array_equivalent_object(left, right, strict_nan)

# NaNs can occur in float and complex arrays.
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):

# empty
if not (np.prod(left.shape) and np.prod(right.shape)):
return True
return ((left == right) | (isna(left) & isna(right))).all()
Expand All @@ -452,6 +445,45 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
return np.array_equal(left, right)


def array_equivalent_float(left, right):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe make all of these helpers private functions

return ((left == right) | (np.isnan(left) & np.isnan(right))).all()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about use_inf_as_na?

Can this be re-used on L426? (also the prod(shape) check on L424-425 i think is redundant with earlier shape check)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like 1.0.5 doesn't care about that

In [15]: df1 = pd.DataFrame({"A": np.array([np.nan, 1, np.inf])})

In [16]: df2 = pd.DataFrame({"A": np.array([np.nan, 1, np.nan])})

In [17]: with pd.option_context('mode.use_inf_as_na', True):
    ...:     print(df1.equals(df2))
    ...:
    ...:
False

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for checking



def array_equivalent_datetimelike(left, right):
    """
    Compare two datetime-like arrays through their 8-byte integer views.

    Both inputs are reinterpreted with ``.view("i8")`` and the resulting
    integer arrays are handed to ``np.array_equal``, so every element
    (missing values included) is compared by its integer representation.

    Parameters
    ----------
    left, right : array-like supporting ``.view("i8")``

    Returns
    -------
    bool
    """
    left_i8 = left.view("i8")
    right_i8 = right.view("i8")
    return np.array_equal(left_i8, right_i8)


def array_equivalent_object(left, right, strict_nan):
    """
    Compare two object arrays, treating missing values in corresponding
    locations as equal.

    Parameters
    ----------
    left, right : ndarrays
    strict_nan : bool
        If False, the comparison is delegated to
        ``lib.array_equivalent_object``, where NaN and None are considered
        equivalent. If True, the pairs are walked one by one and NaT, NA
        and float NaN each only match a missing value of the same kind.

    Returns
    -------
    bool
    """
    if not strict_nan:
        # isna considers NaN and None to be equivalent.
        return lib.array_equivalent_object(
            ensure_object(left.ravel()), ensure_object(right.ravel())
        )

    for left_value, right_value in zip(left, right):
        if left_value is NaT:
            # NaT only matches NaT. A matching pair must not reach the
            # generic ``!=`` check below: ``NaT != NaT`` is True, which
            # would wrongly report identical arrays as different.
            if right_value is not NaT:
                return False

        elif left_value is libmissing.NA:
            # NA only matches NA. ``NA != NA`` evaluates to NA, whose truth
            # value is ambiguous, so a matching pair is accepted here.
            if right_value is not libmissing.NA:
                return False

        elif isinstance(left_value, float) and np.isnan(left_value):
            # A float NaN only matches another float NaN (not None/NaT/NA).
            if not isinstance(right_value, float) or not np.isnan(right_value):
                return False
        else:
            try:
                # np.asarray/np.any handle nested array elements, where
                # ``!=`` yields an array rather than a scalar.
                if np.any(np.asarray(left_value != right_value)):
                    return False
            except TypeError as err:
                if "Cannot compare tz-naive" in str(err):
                    # tzawareness compat failure, see GH#28507
                    return False
                elif "boolean value of NA is ambiguous" in str(err):
                    return False
                raise
    return True


def _infer_fill_value(val):
"""
infer the fill value for the nan/NaT from the provided
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ def equals(self, other: "BlockManager") -> bool:
return array_equivalent(left, right)

for i in range(len(self.items)):
# Check column-wise, return False if any column doesnt match
# Check column-wise, return False if any column doesn't match
left = self.iget_values(i)
right = other.iget_values(i)
if not is_dtype_equal(left.dtype, right.dtype):
Expand All @@ -1445,7 +1445,7 @@ def equals(self, other: "BlockManager") -> bool:
if not left.equals(right):
return False
else:
if not array_equivalent(left, right):
if not array_equivalent(left, right, dtype_equal=True):
return False
return True

Expand Down
59 changes: 47 additions & 12 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,50 +300,80 @@ def test_period(self):
tm.assert_series_equal(notna(s), ~exp)


def test_array_equivalent():
assert array_equivalent(np.array([np.nan, np.nan]), np.array([np.nan, np.nan]))
@pytest.mark.parametrize("dtype_equal", [True, False])
def test_array_equivalent(dtype_equal):
assert array_equivalent(
np.array([np.nan, 1, np.nan]), np.array([np.nan, 1, np.nan])
np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), dtype_equal=dtype_equal
)
assert array_equivalent(
np.array([np.nan, 1, np.nan]),
np.array([np.nan, 1, np.nan]),
dtype_equal=dtype_equal,
)
assert array_equivalent(
np.array([np.nan, None], dtype="object"),
np.array([np.nan, None], dtype="object"),
dtype_equal=dtype_equal,
)
# Check the handling of nested arrays in array_equivalent_object
assert array_equivalent(
np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"),
np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"),
dtype_equal=dtype_equal,
)
assert array_equivalent(
np.array([np.nan, 1 + 1j], dtype="complex"),
np.array([np.nan, 1 + 1j], dtype="complex"),
dtype_equal=dtype_equal,
)
assert not array_equivalent(
np.array([np.nan, 1 + 1j], dtype="complex"),
np.array([np.nan, 1 + 2j], dtype="complex"),
dtype_equal=dtype_equal,
)
assert not array_equivalent(
np.array([np.nan, 1, np.nan]),
np.array([np.nan, 2, np.nan]),
dtype_equal=dtype_equal,
)
assert not array_equivalent(
np.array(["a", "b", "c", "d"]), np.array(["e", "e"]), dtype_equal=dtype_equal
)
assert array_equivalent(
Float64Index([0, np.nan]), Float64Index([0, np.nan]), dtype_equal=dtype_equal
)
assert not array_equivalent(
np.array([np.nan, 1, np.nan]), np.array([np.nan, 2, np.nan])
Float64Index([0, np.nan]), Float64Index([1, np.nan]), dtype_equal=dtype_equal
)
assert array_equivalent(
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]), dtype_equal=dtype_equal
)
assert not array_equivalent(
DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]), dtype_equal=dtype_equal
)
assert array_equivalent(
TimedeltaIndex([0, np.nan]),
TimedeltaIndex([0, np.nan]),
dtype_equal=dtype_equal,
)
assert not array_equivalent(np.array(["a", "b", "c", "d"]), np.array(["e", "e"]))
assert array_equivalent(Float64Index([0, np.nan]), Float64Index([0, np.nan]))
assert not array_equivalent(Float64Index([0, np.nan]), Float64Index([1, np.nan]))
assert array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]))
assert not array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]))
assert array_equivalent(TimedeltaIndex([0, np.nan]), TimedeltaIndex([0, np.nan]))
assert not array_equivalent(
TimedeltaIndex([0, np.nan]), TimedeltaIndex([1, np.nan])
TimedeltaIndex([0, np.nan]),
TimedeltaIndex([1, np.nan]),
dtype_equal=dtype_equal,
)
assert array_equivalent(
DatetimeIndex([0, np.nan], tz="US/Eastern"),
DatetimeIndex([0, np.nan], tz="US/Eastern"),
dtype_equal=dtype_equal,
)
assert not array_equivalent(
DatetimeIndex([0, np.nan], tz="US/Eastern"),
DatetimeIndex([1, np.nan], tz="US/Eastern"),
dtype_equal=dtype_equal,
)
# The rest are not dtype_equal
assert not array_equivalent(
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern")
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern"),
)
assert not array_equivalent(
DatetimeIndex([0, np.nan], tz="CET"),
Expand All @@ -353,6 +383,11 @@ def test_array_equivalent():
assert not array_equivalent(DatetimeIndex([0, np.nan]), TimedeltaIndex([0, np.nan]))


def test_array_equivalent_different_dtype_but_equal():
    """Integer and float arrays holding the same values compare as equivalent."""
    # Unclear if this is exposed anywhere in the public-facing API
    int_values = np.array([1, 2])
    float_values = np.array([1.0, 2.0])
    assert array_equivalent(int_values, float_values)


@pytest.mark.parametrize(
"lvalue, rvalue",
[
Expand Down