|
11 | 11 |
|
12 | 12 | import pandas as pd
|
13 | 13 | from pandas import DataFrame, Series, compat
|
| 14 | +from pandas.core.arrays.sparse import SparseArray |
14 | 15 | import pandas.util.testing as tm
|
15 | 16 | from pandas.util.testing import (
|
16 |
| - RNGContext, assert_almost_equal, assert_frame_equal, assert_index_equal, |
17 |
| - assert_numpy_array_equal, assert_series_equal) |
| 17 | + RNGContext, assert_almost_equal, assert_extension_array_equal, |
| 18 | + assert_frame_equal, assert_index_equal, assert_numpy_array_equal, |
| 19 | + assert_series_equal) |
18 | 20 |
|
19 | 21 |
|
20 | 22 | class TestAssertAlmostEqual(object):
|
@@ -850,6 +852,92 @@ def test_interval_array_equal_message(self):
|
850 | 852 | tm.assert_interval_array_equal(a, b)
|
851 | 853 |
|
852 | 854 |
|
| 855 | +class TestAssertExtensionArrayEqual(object): |
| 856 | + |
| 857 | + def test_check_exact(self): |
| 858 | + # GH 23709 |
| 859 | + left = SparseArray([-0.17387645482451206, 0.3414148016424936]) |
| 860 | + right = SparseArray([-0.17387645482451206, 0.3414148016424937]) |
| 861 | + |
| 862 | + # passes with check_exact=False (should be default) |
| 863 | + assert_extension_array_equal(left, right) |
| 864 | + assert_extension_array_equal(left, right, check_exact=False) |
| 865 | + |
| 866 | + # raises with check_exact=True |
| 867 | + msg = textwrap.dedent("""\ |
| 868 | + ExtensionArray are different |
| 869 | +
|
| 870 | + ExtensionArray values are different \\(50\\.0 %\\) |
| 871 | + \\[left\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\] |
| 872 | + \\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]""") |
| 873 | + with pytest.raises(AssertionError, match=msg): |
| 874 | + assert_extension_array_equal(left, right, check_exact=True) |
| 875 | + |
| 876 | + @pytest.mark.parametrize('check_less_precise', [True, 0, 1, 2, 3, 4]) |
| 877 | + def test_check_less_precise_passes(self, check_less_precise): |
| 878 | + left = SparseArray([0.5, 0.123456]) |
| 879 | + right = SparseArray([0.5, 0.123457]) |
| 880 | + assert_extension_array_equal( |
| 881 | + left, right, check_less_precise=check_less_precise) |
| 882 | + |
| 883 | + @pytest.mark.parametrize('check_less_precise', [False, 5, 6, 7, 8, 9]) |
| 884 | + def test_check_less_precise_fails(self, check_less_precise): |
| 885 | + left = SparseArray([0.5, 0.123456]) |
| 886 | + right = SparseArray([0.5, 0.123457]) |
| 887 | + |
| 888 | + msg = textwrap.dedent("""\ |
| 889 | + ExtensionArray are different |
| 890 | +
|
| 891 | + ExtensionArray values are different \\(50\\.0 %\\) |
| 892 | + \\[left\\]: \\[0\\.5, 0\\.123456\\] |
| 893 | + \\[right\\]: \\[0\\.5, 0\\.123457\\]""") |
| 894 | + with pytest.raises(AssertionError, match=msg): |
| 895 | + assert_extension_array_equal( |
| 896 | + left, right, check_less_precise=check_less_precise) |
| 897 | + |
| 898 | + def test_check_dtype(self): |
| 899 | + left = SparseArray(np.arange(5, dtype='int64')) |
| 900 | + right = SparseArray(np.arange(5, dtype='int32')) |
| 901 | + |
| 902 | + # passes with check_dtype=False |
| 903 | + assert_extension_array_equal(left, right, check_dtype=False) |
| 904 | + |
| 905 | + # raises with check_dtype=True |
| 906 | + msg = textwrap.dedent("""\ |
| 907 | + ExtensionArray are different |
| 908 | +
|
| 909 | + Attribute "dtype" are different |
| 910 | + \\[left\\]: Sparse\\[int64, 0\\] |
| 911 | + \\[right\\]: Sparse\\[int32, 0\\]""") |
| 912 | + with pytest.raises(AssertionError, match=msg): |
| 913 | + assert_extension_array_equal(left, right, check_dtype=True) |
| 914 | + |
| 915 | + def test_missing_values(self): |
| 916 | + left = SparseArray([np.nan, 1, 2, np.nan]) |
| 917 | + right = SparseArray([np.nan, 1, 2, 3]) |
| 918 | + |
| 919 | + msg = textwrap.dedent("""\ |
| 920 | + ExtensionArray NA mask are different |
| 921 | +
|
| 922 | + ExtensionArray NA mask values are different \\(25\\.0 %\\) |
| 923 | + \\[left\\]: \\[True, False, False, True\\] |
| 924 | + \\[right\\]: \\[True, False, False, False\\]""") |
| 925 | + with pytest.raises(AssertionError, match=msg): |
| 926 | + assert_extension_array_equal(left, right) |
| 927 | + |
| 928 | + def test_non_extension_array(self): |
| 929 | + numpy_array = np.arange(5) |
| 930 | + extension_array = SparseArray(np.arange(5)) |
| 931 | + |
| 932 | + msg = 'left is not an ExtensionArray' |
| 933 | + with pytest.raises(AssertionError, match=msg): |
| 934 | + assert_extension_array_equal(numpy_array, extension_array) |
| 935 | + |
| 936 | + msg = 'right is not an ExtensionArray' |
| 937 | + with pytest.raises(AssertionError, match=msg): |
| 938 | + assert_extension_array_equal(extension_array, numpy_array) |
| 939 | + |
| 940 | + |
853 | 941 | class TestRNGContext(object):
|
854 | 942 |
|
855 | 943 | def test_RNGContext(self):
|
|
0 commit comments