Skip to content

Commit 32ebcfc

Browse files
authored
PERF: assert_frame_equal / assert_series_equal (#55971)
* improve perf of index assertions * whatsnew * faster _array_equivalent_object * add comment * remove xfail * skip mask if not needed
1 parent dbf8aaf commit 32ebcfc

File tree

4 files changed

+30
-28
lines changed

4 files changed

+30
-28
lines changed

doc/source/whatsnew/v2.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ Other Deprecations
309309

310310
Performance improvements
311311
~~~~~~~~~~~~~~~~~~~~~~~~
312-
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`)
312+
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` (:issue:`55949`, :issue:`55971`)
313313
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
314314
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
315315
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)

pandas/_testing/asserters.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
Series,
4444
TimedeltaIndex,
4545
)
46-
from pandas.core.algorithms import take_nd
4746
from pandas.core.arrays import (
4847
DatetimeArray,
4948
ExtensionArray,
@@ -246,13 +245,6 @@ def _check_types(left, right, obj: str = "Index") -> None:
246245

247246
assert_attr_equal("dtype", left, right, obj=obj)
248247

249-
def _get_ilevel_values(index, level):
250-
# accept level number only
251-
unique = index.levels[level]
252-
level_codes = index.codes[level]
253-
filled = take_nd(unique._values, level_codes, fill_value=unique._na_value)
254-
return unique._shallow_copy(filled, name=index.names[level])
255-
256248
# instance validation
257249
_check_isinstance(left, right, Index)
258250

@@ -299,9 +291,8 @@ def _get_ilevel_values(index, level):
299291
)
300292
assert_numpy_array_equal(left.codes[level], right.codes[level])
301293
except AssertionError:
302-
# cannot use get_level_values here because it can change dtype
303-
llevel = _get_ilevel_values(left, level)
304-
rlevel = _get_ilevel_values(right, level)
294+
llevel = left.get_level_values(level)
295+
rlevel = right.get_level_values(level)
305296

306297
assert_index_equal(
307298
llevel,
@@ -592,7 +583,7 @@ def raise_assert_detail(
592583
{message}"""
593584

594585
if isinstance(index_values, Index):
595-
index_values = np.array(index_values)
586+
index_values = np.asarray(index_values)
596587

597588
if isinstance(index_values, np.ndarray):
598589
msg += f"\n[index]: {pprint_thing(index_values)}"

pandas/core/dtypes/missing.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -562,12 +562,29 @@ def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):
562562

563563

564564
def _array_equivalent_object(left: np.ndarray, right: np.ndarray, strict_nan: bool):
565-
if not strict_nan:
566-
# isna considers NaN and None to be equivalent.
567-
568-
return lib.array_equivalent_object(ensure_object(left), ensure_object(right))
569-
570-
for left_value, right_value in zip(left, right):
565+
left = ensure_object(left)
566+
right = ensure_object(right)
567+
568+
mask: npt.NDArray[np.bool_] | None = None
569+
if strict_nan:
570+
mask = isna(left) & isna(right)
571+
if not mask.any():
572+
mask = None
573+
574+
try:
575+
if mask is None:
576+
return lib.array_equivalent_object(left, right)
577+
if not lib.array_equivalent_object(left[~mask], right[~mask]):
578+
return False
579+
left_remaining = left[mask]
580+
right_remaining = right[mask]
581+
except ValueError:
582+
# can raise a ValueError if left and right cannot be
583+
# compared (e.g. nested arrays)
584+
left_remaining = left
585+
right_remaining = right
586+
587+
for left_value, right_value in zip(left_remaining, right_remaining):
571588
if left_value is NaT and right_value is not NaT:
572589
return False
573590

pandas/tests/dtypes/test_missing.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -560,9 +560,7 @@ def test_array_equivalent_str(dtype):
560560
)
561561

562562

563-
@pytest.mark.parametrize(
564-
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
565-
)
563+
@pytest.mark.parametrize("strict_nan", [True, False])
566564
def test_array_equivalent_nested(strict_nan):
567565
# reached in groupby aggregations, make sure we use np.any when checking
568566
# if the comparison is truthy
@@ -585,9 +583,7 @@ def test_array_equivalent_nested(strict_nan):
585583

586584

587585
@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning")
588-
@pytest.mark.parametrize(
589-
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
590-
)
586+
@pytest.mark.parametrize("strict_nan", [True, False])
591587
def test_array_equivalent_nested2(strict_nan):
592588
# more than one level of nesting
593589
left = np.array(
@@ -612,9 +608,7 @@ def test_array_equivalent_nested2(strict_nan):
612608
assert not array_equivalent(left, right, strict_nan=strict_nan)
613609

614610

615-
@pytest.mark.parametrize(
616-
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
617-
)
611+
@pytest.mark.parametrize("strict_nan", [True, False])
618612
def test_array_equivalent_nested_list(strict_nan):
619613
left = np.array([[50, 70, 90], [20, 30]], dtype=object)
620614
right = np.array([[50, 70, 90], [20, 30]], dtype=object)

0 commit comments

Comments
 (0)