diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index efa4a52993a90..b313478eb4985 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -304,6 +304,7 @@ Other Deprecations Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ +- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`) - Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`) - Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`) - Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`) diff --git a/pandas/_libs/testing.pyx b/pandas/_libs/testing.pyx index 4ba7bce51ed64..aed0f4b082d4e 100644 --- a/pandas/_libs/testing.pyx +++ b/pandas/_libs/testing.pyx @@ -78,7 +78,7 @@ cpdef assert_almost_equal(a, b, robj : str, default None Specify right object name being compared, internally used to show appropriate assertion message. - index_values : ndarray, default None + index_values : Index | ndarray, default None Specify shared index values of objects being compared, internally used to show appropriate assertion message. diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index 5f14d46be8e70..8e49fcfb355fa 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -283,22 +283,37 @@ def _get_ilevel_values(index, level): right = cast(MultiIndex, right) for level in range(left.nlevels): - # cannot use get_level_values here because it can change dtype - llevel = _get_ilevel_values(left, level) - rlevel = _get_ilevel_values(right, level) - lobj = f"MultiIndex level [{level}]" - assert_index_equal( - llevel, - rlevel, - exact=exact, - check_names=check_names, - check_exact=check_exact, - check_categorical=check_categorical, - rtol=rtol, - atol=atol, - obj=lobj, - ) + try: + # try comparison on levels/codes to avoid densifying MultiIndex + assert_index_equal( + left.levels[level], + right.levels[level], + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) + assert_numpy_array_equal(left.codes[level], right.codes[level]) + except AssertionError: + # cannot use get_level_values here because it can change dtype + llevel = _get_ilevel_values(left, level) + rlevel = _get_ilevel_values(right, level) + + assert_index_equal( + llevel, + rlevel, + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) # get_level_values may change dtype _check_types(left.levels[level], right.levels[level], obj=obj) @@ -576,6 +591,9 @@ def raise_assert_detail( {message}""" + if isinstance(index_values, Index): + index_values = np.array(index_values) + if isinstance(index_values, np.ndarray): msg += f"\n[index]: {pprint_thing(index_values)}" @@ -630,7 +648,7 @@ def assert_numpy_array_equal( obj : str, default 'numpy array' Specify object name being compared, internally used to show appropriate assertion message. - index_values : numpy.ndarray, default None + index_values : Index | numpy.ndarray, default None optional index (shared by both left and right), used in output. """ __tracebackhide__ = True @@ -701,7 +719,7 @@ def assert_extension_array_equal( The two arrays to compare. check_dtype : bool, default True Whether to check if the ExtensionArray dtypes are identical. - index_values : numpy.ndarray, default None + index_values : Index | numpy.ndarray, default None Optional index (shared by both left and right), used in output. check_exact : bool, default False Whether to compare number exactly. @@ -932,7 +950,7 @@ def assert_series_equal( left_values, right_values, check_dtype=check_dtype, - index_values=np.asarray(left.index), + index_values=left.index, obj=str(obj), ) else: @@ -941,7 +959,7 @@ def assert_series_equal( right_values, check_dtype=check_dtype, obj=str(obj), - index_values=np.asarray(left.index), + index_values=left.index, ) elif check_datetimelike_compat and ( needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype) @@ -972,7 +990,7 @@ def assert_series_equal( atol=atol, check_dtype=bool(check_dtype), obj=str(obj), - index_values=np.asarray(left.index), + index_values=left.index, ) elif isinstance(left.dtype, ExtensionDtype) and isinstance( right.dtype, ExtensionDtype @@ -983,7 +1001,7 @@ def assert_series_equal( rtol=rtol, atol=atol, check_dtype=check_dtype, - index_values=np.asarray(left.index), + index_values=left.index, obj=str(obj), ) elif is_extension_array_dtype_and_needs_i8_conversion( @@ -993,7 +1011,7 @@ def assert_series_equal( left._values, right._values, check_dtype=check_dtype, - index_values=np.asarray(left.index), + index_values=left.index, obj=str(obj), ) elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype): @@ -1002,7 +1020,7 @@ def assert_series_equal( left._values, right._values, check_dtype=check_dtype, - index_values=np.asarray(left.index), + index_values=left.index, obj=str(obj), ) else: @@ -1013,7 +1031,7 @@ def assert_series_equal( atol=atol, check_dtype=bool(check_dtype), obj=str(obj), - index_values=np.asarray(left.index), + index_values=left.index, ) # metadata comparison diff --git a/pandas/tests/frame/methods/test_value_counts.py b/pandas/tests/frame/methods/test_value_counts.py index f30db91f82b60..4136d641ef67f 100644 --- a/pandas/tests/frame/methods/test_value_counts.py +++ b/pandas/tests/frame/methods/test_value_counts.py @@ -147,7 +147,7 @@ def test_data_frame_value_counts_dropna_false(nulls_fixture): index=pd.MultiIndex( levels=[ pd.Index(["Anne", "Beth", "John"]), - pd.Index(["Louise", "Smith", nulls_fixture]), + pd.Index(["Louise", "Smith", np.nan]), ], codes=[[0, 1, 2, 2], [2, 0, 1, 2]], names=["first_name", "middle_name"],