From 135ea8778d7197ea399825afecf45f7775ccd995 Mon Sep 17 00:00:00 2001 From: Anthony Milbourne <18662115+amilbourne@users.noreply.github.com> Date: Wed, 15 Apr 2020 22:59:22 +0100 Subject: [PATCH] ENH: Added index to output of assert_series_equal on categorical and datetime values --- pandas/_testing.py | 31 +++++++++++++---- pandas/tests/util/test_assert_series_equal.py | 34 ++++++++++++++++++- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/pandas/_testing.py b/pandas/_testing.py index 1f6b645c821c8..d2b48b54e8ab0 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -999,7 +999,12 @@ def _raise(left, right, err_msg): def assert_extension_array_equal( - left, right, check_dtype=True, check_less_precise=False, check_exact=False + left, + right, + check_dtype=True, + check_less_precise=False, + check_exact=False, + index_values=None, ): """ Check that left and right ExtensionArrays are equal. @@ -1016,6 +1021,8 @@ def assert_extension_array_equal( If int, then specify the digits to compare. check_exact : bool, default False Whether to compare number exactly. + index_values : numpy.ndarray, default None + optional index (shared by both left and right), used in output. Notes ----- @@ -1031,17 +1038,23 @@ def assert_extension_array_equal( if hasattr(left, "asi8") and type(right) == type(left): # Avoid slow object-dtype comparisons # np.asarray for case where we have a np.MaskedArray - assert_numpy_array_equal(np.asarray(left.asi8), np.asarray(right.asi8)) + assert_numpy_array_equal( + np.asarray(left.asi8), np.asarray(right.asi8), index_values=index_values + ) return left_na = np.asarray(left.isna()) right_na = np.asarray(right.isna()) - assert_numpy_array_equal(left_na, right_na, obj="ExtensionArray NA mask") + assert_numpy_array_equal( + left_na, right_na, obj="ExtensionArray NA mask", index_values=index_values + ) left_valid = np.asarray(left[~left_na].astype(object)) right_valid = np.asarray(right[~right_na].astype(object)) if check_exact: - assert_numpy_array_equal(left_valid, right_valid, obj="ExtensionArray") + assert_numpy_array_equal( + left_valid, right_valid, obj="ExtensionArray", index_values=index_values + ) else: _testing.assert_almost_equal( left_valid, @@ -1049,6 +1062,7 @@ def assert_extension_array_equal( check_dtype=check_dtype, check_less_precise=check_less_precise, obj="ExtensionArray", + index_values=index_values, ) @@ -1181,12 +1195,17 @@ def assert_series_equal( check_less_precise=check_less_precise, check_dtype=check_dtype, obj=str(obj), + index_values=np.asarray(left.index), ) elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype): - assert_extension_array_equal(left._values, right._values) + assert_extension_array_equal( + left._values, right._values, index_values=np.asarray(left.index) + ) elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype): # DatetimeArray or TimedeltaArray - assert_extension_array_equal(left._values, right._values) + assert_extension_array_equal( + left._values, right._values, index_values=np.asarray(left.index) + ) else: _testing.assert_almost_equal( left._values, diff --git a/pandas/tests/util/test_assert_series_equal.py b/pandas/tests/util/test_assert_series_equal.py index 8bf3d82672695..337a06b91e443 100644 --- a/pandas/tests/util/test_assert_series_equal.py +++ b/pandas/tests/util/test_assert_series_equal.py @@ -165,7 +165,7 @@ def test_series_equal_length_mismatch(check_less_precise): tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise) -def test_series_equal_values_mismatch(check_less_precise): +def test_series_equal_numeric_values_mismatch(check_less_precise): msg = """Series are different Series values are different \\(33\\.33333 %\\) @@ -180,6 +180,38 @@ def test_series_equal_values_mismatch(check_less_precise): tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise) +def test_series_equal_categorical_values_mismatch(check_less_precise): + msg = """Series are different + +Series values are different \\(66\\.66667 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[a, b, c\\] +Categories \\(3, object\\): \\[a, b, c\\] +\\[right\\]: \\[a, c, b\\] +Categories \\(3, object\\): \\[a, b, c\\]""" + + s1 = Series(Categorical(["a", "b", "c"])) + s2 = Series(Categorical(["a", "c", "b"])) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise) + + +def test_series_equal_datetime_values_mismatch(check_less_precise): + msg = """numpy array are different + +numpy array values are different \\(100.0 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\] +\\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]""" + + s1 = Series(pd.date_range("2018-01-01", periods=3, freq="D")) + s2 = Series(pd.date_range("2019-02-02", periods=3, freq="D")) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, check_less_precise=check_less_precise) + + def test_series_equal_categorical_mismatch(check_categorical): msg = """Attributes of Series are different