@@ -1024,7 +1024,12 @@ def _raise(left, right, err_msg):
1024
1024
1025
1025
1026
1026
def assert_extension_array_equal (
1027
- left , right , check_dtype = True , check_less_precise = False , check_exact = False
1027
+ left ,
1028
+ right ,
1029
+ check_dtype = True ,
1030
+ check_less_precise = False ,
1031
+ check_exact = False ,
1032
+ index_values = None ,
1028
1033
):
1029
1034
"""
1030
1035
Check that left and right ExtensionArrays are equal.
@@ -1041,6 +1046,8 @@ def assert_extension_array_equal(
1041
1046
If int, then specify the digits to compare.
1042
1047
check_exact : bool, default False
1043
1048
Whether to compare number exactly.
1049
+ index_values : numpy.ndarray, default None
1050
+ optional index (shared by both left and right), used in output.
1044
1051
1045
1052
Notes
1046
1053
-----
@@ -1056,24 +1063,31 @@ def assert_extension_array_equal(
1056
1063
if hasattr (left , "asi8" ) and type (right ) == type (left ):
1057
1064
# Avoid slow object-dtype comparisons
1058
1065
# np.asarray for case where we have a np.MaskedArray
1059
- assert_numpy_array_equal (np .asarray (left .asi8 ), np .asarray (right .asi8 ))
1066
+ assert_numpy_array_equal (
1067
+ np .asarray (left .asi8 ), np .asarray (right .asi8 ), index_values = index_values
1068
+ )
1060
1069
return
1061
1070
1062
1071
left_na = np .asarray (left .isna ())
1063
1072
right_na = np .asarray (right .isna ())
1064
- assert_numpy_array_equal (left_na , right_na , obj = "ExtensionArray NA mask" )
1073
+ assert_numpy_array_equal (
1074
+ left_na , right_na , obj = "ExtensionArray NA mask" , index_values = index_values
1075
+ )
1065
1076
1066
1077
left_valid = np .asarray (left [~ left_na ].astype (object ))
1067
1078
right_valid = np .asarray (right [~ right_na ].astype (object ))
1068
1079
if check_exact :
1069
- assert_numpy_array_equal (left_valid , right_valid , obj = "ExtensionArray" )
1080
+ assert_numpy_array_equal (
1081
+ left_valid , right_valid , obj = "ExtensionArray" , index_values = index_values
1082
+ )
1070
1083
else :
1071
1084
_testing .assert_almost_equal (
1072
1085
left_valid ,
1073
1086
right_valid ,
1074
1087
check_dtype = check_dtype ,
1075
1088
check_less_precise = check_less_precise ,
1076
1089
obj = "ExtensionArray" ,
1090
+ index_values = index_values ,
1077
1091
)
1078
1092
1079
1093
@@ -1206,12 +1220,17 @@ def assert_series_equal(
1206
1220
check_less_precise = check_less_precise ,
1207
1221
check_dtype = check_dtype ,
1208
1222
obj = str (obj ),
1223
+ index_values = np .asarray (left .index ),
1209
1224
)
1210
1225
elif is_extension_array_dtype (left .dtype ) and is_extension_array_dtype (right .dtype ):
1211
- assert_extension_array_equal (left ._values , right ._values )
1226
+ assert_extension_array_equal (
1227
+ left ._values , right ._values , index_values = np .asarray (left .index )
1228
+ )
1212
1229
elif needs_i8_conversion (left .dtype ) or needs_i8_conversion (right .dtype ):
1213
1230
# DatetimeArray or TimedeltaArray
1214
- assert_extension_array_equal (left ._values , right ._values )
1231
+ assert_extension_array_equal (
1232
+ left ._values , right ._values , index_values = np .asarray (left .index )
1233
+ )
1215
1234
else :
1216
1235
_testing .assert_almost_equal (
1217
1236
left ._values ,
0 commit comments