Skip to content

Commit 00d88e9

Browse files
authored
PERF: assert_frame_equal and assert_series_equal for frames/series with a MultiIndex (#55949)
* avoid "densifying" multiindex * whatsnew * update test
1 parent 7f0b890 commit 00d88e9

File tree

4 files changed

+45
-26
lines changed

4 files changed

+45
-26
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ Other Deprecations
305305

306306
Performance improvements
307307
~~~~~~~~~~~~~~~~~~~~~~~~
308+
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`)
308309
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
309310
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
310311
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)

pandas/_libs/testing.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ cpdef assert_almost_equal(a, b,
7878
robj : str, default None
7979
Specify right object name being compared, internally used to show
8080
appropriate assertion message.
81-
index_values : ndarray, default None
81+
index_values : Index | ndarray, default None
8282
Specify shared index values of objects being compared, internally used
8383
to show appropriate assertion message.
8484

pandas/_testing/asserters.py

+42-24
Original file line numberDiff line numberDiff line change
@@ -283,22 +283,37 @@ def _get_ilevel_values(index, level):
283283
right = cast(MultiIndex, right)
284284

285285
for level in range(left.nlevels):
286-
# cannot use get_level_values here because it can change dtype
287-
llevel = _get_ilevel_values(left, level)
288-
rlevel = _get_ilevel_values(right, level)
289-
290286
lobj = f"MultiIndex level [{level}]"
291-
assert_index_equal(
292-
llevel,
293-
rlevel,
294-
exact=exact,
295-
check_names=check_names,
296-
check_exact=check_exact,
297-
check_categorical=check_categorical,
298-
rtol=rtol,
299-
atol=atol,
300-
obj=lobj,
301-
)
287+
try:
288+
# try comparison on levels/codes to avoid densifying MultiIndex
289+
assert_index_equal(
290+
left.levels[level],
291+
right.levels[level],
292+
exact=exact,
293+
check_names=check_names,
294+
check_exact=check_exact,
295+
check_categorical=check_categorical,
296+
rtol=rtol,
297+
atol=atol,
298+
obj=lobj,
299+
)
300+
assert_numpy_array_equal(left.codes[level], right.codes[level])
301+
except AssertionError:
302+
# cannot use get_level_values here because it can change dtype
303+
llevel = _get_ilevel_values(left, level)
304+
rlevel = _get_ilevel_values(right, level)
305+
306+
assert_index_equal(
307+
llevel,
308+
rlevel,
309+
exact=exact,
310+
check_names=check_names,
311+
check_exact=check_exact,
312+
check_categorical=check_categorical,
313+
rtol=rtol,
314+
atol=atol,
315+
obj=lobj,
316+
)
302317
# get_level_values may change dtype
303318
_check_types(left.levels[level], right.levels[level], obj=obj)
304319

@@ -576,6 +591,9 @@ def raise_assert_detail(
576591
577592
{message}"""
578593

594+
if isinstance(index_values, Index):
595+
index_values = np.array(index_values)
596+
579597
if isinstance(index_values, np.ndarray):
580598
msg += f"\n[index]: {pprint_thing(index_values)}"
581599

@@ -630,7 +648,7 @@ def assert_numpy_array_equal(
630648
obj : str, default 'numpy array'
631649
Specify object name being compared, internally used to show appropriate
632650
assertion message.
633-
index_values : numpy.ndarray, default None
651+
index_values : Index | numpy.ndarray, default None
634652
optional index (shared by both left and right), used in output.
635653
"""
636654
__tracebackhide__ = True
@@ -701,7 +719,7 @@ def assert_extension_array_equal(
701719
The two arrays to compare.
702720
check_dtype : bool, default True
703721
Whether to check if the ExtensionArray dtypes are identical.
704-
index_values : numpy.ndarray, default None
722+
index_values : Index | numpy.ndarray, default None
705723
Optional index (shared by both left and right), used in output.
706724
check_exact : bool, default False
707725
Whether to compare number exactly.
@@ -932,7 +950,7 @@ def assert_series_equal(
932950
left_values,
933951
right_values,
934952
check_dtype=check_dtype,
935-
index_values=np.asarray(left.index),
953+
index_values=left.index,
936954
obj=str(obj),
937955
)
938956
else:
@@ -941,7 +959,7 @@ def assert_series_equal(
941959
right_values,
942960
check_dtype=check_dtype,
943961
obj=str(obj),
944-
index_values=np.asarray(left.index),
962+
index_values=left.index,
945963
)
946964
elif check_datetimelike_compat and (
947965
needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype)
@@ -972,7 +990,7 @@ def assert_series_equal(
972990
atol=atol,
973991
check_dtype=bool(check_dtype),
974992
obj=str(obj),
975-
index_values=np.asarray(left.index),
993+
index_values=left.index,
976994
)
977995
elif isinstance(left.dtype, ExtensionDtype) and isinstance(
978996
right.dtype, ExtensionDtype
@@ -983,7 +1001,7 @@ def assert_series_equal(
9831001
rtol=rtol,
9841002
atol=atol,
9851003
check_dtype=check_dtype,
986-
index_values=np.asarray(left.index),
1004+
index_values=left.index,
9871005
obj=str(obj),
9881006
)
9891007
elif is_extension_array_dtype_and_needs_i8_conversion(
@@ -993,7 +1011,7 @@ def assert_series_equal(
9931011
left._values,
9941012
right._values,
9951013
check_dtype=check_dtype,
996-
index_values=np.asarray(left.index),
1014+
index_values=left.index,
9971015
obj=str(obj),
9981016
)
9991017
elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype):
@@ -1002,7 +1020,7 @@ def assert_series_equal(
10021020
left._values,
10031021
right._values,
10041022
check_dtype=check_dtype,
1005-
index_values=np.asarray(left.index),
1023+
index_values=left.index,
10061024
obj=str(obj),
10071025
)
10081026
else:
@@ -1013,7 +1031,7 @@ def assert_series_equal(
10131031
atol=atol,
10141032
check_dtype=bool(check_dtype),
10151033
obj=str(obj),
1016-
index_values=np.asarray(left.index),
1034+
index_values=left.index,
10171035
)
10181036

10191037
# metadata comparison

pandas/tests/frame/methods/test_value_counts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_data_frame_value_counts_dropna_false(nulls_fixture):
147147
index=pd.MultiIndex(
148148
levels=[
149149
pd.Index(["Anne", "Beth", "John"]),
150-
pd.Index(["Louise", "Smith", nulls_fixture]),
150+
pd.Index(["Louise", "Smith", np.nan]),
151151
],
152152
codes=[[0, 1, 2, 2], [2, 0, 1, 2]],
153153
names=["first_name", "middle_name"],

0 commit comments

Comments
 (0)