Skip to content

Commit c9faed4

Browse files
authored
ENH: Added index to output of assert_series_equal on categorical and datetime values (#33575)
1 parent dc57f28 commit c9faed4

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

pandas/_testing.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,12 @@ def _raise(left, right, err_msg):
10241024

10251025

10261026
def assert_extension_array_equal(
1027-
left, right, check_dtype=True, check_less_precise=False, check_exact=False
1027+
left,
1028+
right,
1029+
check_dtype=True,
1030+
check_less_precise=False,
1031+
check_exact=False,
1032+
index_values=None,
10281033
):
10291034
"""
10301035
Check that left and right ExtensionArrays are equal.
@@ -1041,6 +1046,8 @@ def assert_extension_array_equal(
10411046
If int, then specify the digits to compare.
10421047
check_exact : bool, default False
10431048
Whether to compare number exactly.
1049+
index_values : numpy.ndarray, default None
1050+
optional index (shared by both left and right), used in output.
10441051
10451052
Notes
10461053
-----
@@ -1056,24 +1063,31 @@ def assert_extension_array_equal(
10561063
if hasattr(left, "asi8") and type(right) == type(left):
10571064
# Avoid slow object-dtype comparisons
10581065
# np.asarray for case where we have a np.MaskedArray
1059-
assert_numpy_array_equal(np.asarray(left.asi8), np.asarray(right.asi8))
1066+
assert_numpy_array_equal(
1067+
np.asarray(left.asi8), np.asarray(right.asi8), index_values=index_values
1068+
)
10601069
return
10611070

10621071
left_na = np.asarray(left.isna())
10631072
right_na = np.asarray(right.isna())
1064-
assert_numpy_array_equal(left_na, right_na, obj="ExtensionArray NA mask")
1073+
assert_numpy_array_equal(
1074+
left_na, right_na, obj="ExtensionArray NA mask", index_values=index_values
1075+
)
10651076

10661077
left_valid = np.asarray(left[~left_na].astype(object))
10671078
right_valid = np.asarray(right[~right_na].astype(object))
10681079
if check_exact:
1069-
assert_numpy_array_equal(left_valid, right_valid, obj="ExtensionArray")
1080+
assert_numpy_array_equal(
1081+
left_valid, right_valid, obj="ExtensionArray", index_values=index_values
1082+
)
10701083
else:
10711084
_testing.assert_almost_equal(
10721085
left_valid,
10731086
right_valid,
10741087
check_dtype=check_dtype,
10751088
check_less_precise=check_less_precise,
10761089
obj="ExtensionArray",
1090+
index_values=index_values,
10771091
)
10781092

10791093

@@ -1206,12 +1220,17 @@ def assert_series_equal(
12061220
check_less_precise=check_less_precise,
12071221
check_dtype=check_dtype,
12081222
obj=str(obj),
1223+
index_values=np.asarray(left.index),
12091224
)
12101225
elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype):
1211-
assert_extension_array_equal(left._values, right._values)
1226+
assert_extension_array_equal(
1227+
left._values, right._values, index_values=np.asarray(left.index)
1228+
)
12121229
elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
12131230
# DatetimeArray or TimedeltaArray
1214-
assert_extension_array_equal(left._values, right._values)
1231+
assert_extension_array_equal(
1232+
left._values, right._values, index_values=np.asarray(left.index)
1233+
)
12151234
else:
12161235
_testing.assert_almost_equal(
12171236
left._values,

pandas/tests/util/test_assert_series_equal.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_series_equal_length_mismatch(check_less_precise):
165165
tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise)
166166

167167

168-
def test_series_equal_values_mismatch(check_less_precise):
168+
def test_series_equal_numeric_values_mismatch(check_less_precise):
169169
msg = """Series are different
170170
171171
Series values are different \\(33\\.33333 %\\)
@@ -180,6 +180,38 @@ def test_series_equal_values_mismatch(check_less_precise):
180180
tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise)
181181

182182

183+
def test_series_equal_categorical_values_mismatch(check_less_precise):
184+
msg = """Series are different
185+
186+
Series values are different \\(66\\.66667 %\\)
187+
\\[index\\]: \\[0, 1, 2\\]
188+
\\[left\\]: \\[a, b, c\\]
189+
Categories \\(3, object\\): \\[a, b, c\\]
190+
\\[right\\]: \\[a, c, b\\]
191+
Categories \\(3, object\\): \\[a, b, c\\]"""
192+
193+
s1 = Series(Categorical(["a", "b", "c"]))
194+
s2 = Series(Categorical(["a", "c", "b"]))
195+
196+
with pytest.raises(AssertionError, match=msg):
197+
tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise)
198+
199+
200+
def test_series_equal_datetime_values_mismatch(check_less_precise):
201+
msg = """numpy array are different
202+
203+
numpy array values are different \\(100.0 %\\)
204+
\\[index\\]: \\[0, 1, 2\\]
205+
\\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\]
206+
\\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]"""
207+
208+
s1 = Series(pd.date_range("2018-01-01", periods=3, freq="D"))
209+
s2 = Series(pd.date_range("2019-02-02", periods=3, freq="D"))
210+
211+
with pytest.raises(AssertionError, match=msg):
212+
tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise)
213+
214+
183215
def test_series_equal_categorical_mismatch(check_categorical):
184216
msg = """Attributes of Series are different
185217

0 commit comments

Comments
 (0)