@@ -283,22 +283,37 @@ def _get_ilevel_values(index, level):
283
283
right = cast (MultiIndex , right )
284
284
285
285
for level in range (left .nlevels ):
286
- # cannot use get_level_values here because it can change dtype
287
- llevel = _get_ilevel_values (left , level )
288
- rlevel = _get_ilevel_values (right , level )
289
-
290
286
lobj = f"MultiIndex level [{ level } ]"
291
- assert_index_equal (
292
- llevel ,
293
- rlevel ,
294
- exact = exact ,
295
- check_names = check_names ,
296
- check_exact = check_exact ,
297
- check_categorical = check_categorical ,
298
- rtol = rtol ,
299
- atol = atol ,
300
- obj = lobj ,
301
- )
287
+ try :
288
+ # try comparison on levels/codes to avoid densifying MultiIndex
289
+ assert_index_equal (
290
+ left .levels [level ],
291
+ right .levels [level ],
292
+ exact = exact ,
293
+ check_names = check_names ,
294
+ check_exact = check_exact ,
295
+ check_categorical = check_categorical ,
296
+ rtol = rtol ,
297
+ atol = atol ,
298
+ obj = lobj ,
299
+ )
300
+ assert_numpy_array_equal (left .codes [level ], right .codes [level ])
301
+ except AssertionError :
302
+ # cannot use get_level_values here because it can change dtype
303
+ llevel = _get_ilevel_values (left , level )
304
+ rlevel = _get_ilevel_values (right , level )
305
+
306
+ assert_index_equal (
307
+ llevel ,
308
+ rlevel ,
309
+ exact = exact ,
310
+ check_names = check_names ,
311
+ check_exact = check_exact ,
312
+ check_categorical = check_categorical ,
313
+ rtol = rtol ,
314
+ atol = atol ,
315
+ obj = lobj ,
316
+ )
302
317
# get_level_values may change dtype
303
318
_check_types (left .levels [level ], right .levels [level ], obj = obj )
304
319
@@ -576,6 +591,9 @@ def raise_assert_detail(
576
591
577
592
{ message } """
578
593
594
+ if isinstance (index_values , Index ):
595
+ index_values = np .array (index_values )
596
+
579
597
if isinstance (index_values , np .ndarray ):
580
598
msg += f"\n [index]: { pprint_thing (index_values )} "
581
599
@@ -630,7 +648,7 @@ def assert_numpy_array_equal(
630
648
obj : str, default 'numpy array'
631
649
Specify object name being compared, internally used to show appropriate
632
650
assertion message.
633
- index_values : numpy.ndarray, default None
651
+ index_values : Index | numpy.ndarray, default None
634
652
optional index (shared by both left and right), used in output.
635
653
"""
636
654
__tracebackhide__ = True
@@ -701,7 +719,7 @@ def assert_extension_array_equal(
701
719
The two arrays to compare.
702
720
check_dtype : bool, default True
703
721
Whether to check if the ExtensionArray dtypes are identical.
704
- index_values : numpy.ndarray, default None
722
+ index_values : Index | numpy.ndarray, default None
705
723
Optional index (shared by both left and right), used in output.
706
724
check_exact : bool, default False
707
725
Whether to compare number exactly.
@@ -932,7 +950,7 @@ def assert_series_equal(
932
950
left_values ,
933
951
right_values ,
934
952
check_dtype = check_dtype ,
935
- index_values = np . asarray ( left .index ) ,
953
+ index_values = left .index ,
936
954
obj = str (obj ),
937
955
)
938
956
else :
@@ -941,7 +959,7 @@ def assert_series_equal(
941
959
right_values ,
942
960
check_dtype = check_dtype ,
943
961
obj = str (obj ),
944
- index_values = np . asarray ( left .index ) ,
962
+ index_values = left .index ,
945
963
)
946
964
elif check_datetimelike_compat and (
947
965
needs_i8_conversion (left .dtype ) or needs_i8_conversion (right .dtype )
@@ -972,7 +990,7 @@ def assert_series_equal(
972
990
atol = atol ,
973
991
check_dtype = bool (check_dtype ),
974
992
obj = str (obj ),
975
- index_values = np . asarray ( left .index ) ,
993
+ index_values = left .index ,
976
994
)
977
995
elif isinstance (left .dtype , ExtensionDtype ) and isinstance (
978
996
right .dtype , ExtensionDtype
@@ -983,7 +1001,7 @@ def assert_series_equal(
983
1001
rtol = rtol ,
984
1002
atol = atol ,
985
1003
check_dtype = check_dtype ,
986
- index_values = np . asarray ( left .index ) ,
1004
+ index_values = left .index ,
987
1005
obj = str (obj ),
988
1006
)
989
1007
elif is_extension_array_dtype_and_needs_i8_conversion (
@@ -993,7 +1011,7 @@ def assert_series_equal(
993
1011
left ._values ,
994
1012
right ._values ,
995
1013
check_dtype = check_dtype ,
996
- index_values = np . asarray ( left .index ) ,
1014
+ index_values = left .index ,
997
1015
obj = str (obj ),
998
1016
)
999
1017
elif needs_i8_conversion (left .dtype ) and needs_i8_conversion (right .dtype ):
@@ -1002,7 +1020,7 @@ def assert_series_equal(
1002
1020
left ._values ,
1003
1021
right ._values ,
1004
1022
check_dtype = check_dtype ,
1005
- index_values = np . asarray ( left .index ) ,
1023
+ index_values = left .index ,
1006
1024
obj = str (obj ),
1007
1025
)
1008
1026
else :
@@ -1013,7 +1031,7 @@ def assert_series_equal(
1013
1031
atol = atol ,
1014
1032
check_dtype = bool (check_dtype ),
1015
1033
obj = str (obj ),
1016
- index_values = np . asarray ( left .index ) ,
1034
+ index_values = left .index ,
1017
1035
)
1018
1036
1019
1037
# metadata comparison
0 commit comments