# Imports touched by this change (the rest of the module's import block is
# outside this view).
import pandas as pd
from pandas import DataFrame, Series, compat
from pandas.core.arrays.sparse import SparseArray
import pandas.util.testing as tm
from pandas.util.testing import (
    RNGContext, assert_almost_equal, assert_extension_array_equal,
    assert_frame_equal, assert_index_equal, assert_numpy_array_equal,
    assert_series_equal)


class TestAssertExtensionArrayEqual(object):
    """Tests for ``assert_extension_array_equal`` using SparseArray inputs."""

    # NOTE: the expected-message patterns below use ``\s+`` after the
    # ``[left]:`` / ``[right]:`` labels so that they do not depend on how
    # many spaces the failure report uses to pad those labels.

    def test_check_exact(self):
        """Values differing in the last digit only fail with check_exact."""
        # GH 23709
        left = SparseArray([-0.17387645482451206, 0.3414148016424936])
        right = SparseArray([-0.17387645482451206, 0.3414148016424937])

        # passes with check_exact=False (should be default)
        assert_extension_array_equal(left, right)
        assert_extension_array_equal(left, right, check_exact=False)

        # raises with check_exact=True
        msg = textwrap.dedent("""\
            ExtensionArray are different

            ExtensionArray values are different \\(50\\.0 %\\)
            \\[left\\]:\\s+\\[-0\\.17387645482.*, 0\\.341414801642.*\\]
            \\[right\\]:\\s+\\[-0\\.17387645482.*, 0\\.341414801642.*\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right, check_exact=True)

    @pytest.mark.parametrize('check_less_precise', [True, 0, 1, 2, 3, 4])
    def test_check_less_precise_passes(self, check_less_precise):
        """A difference in the 6th decimal is ignored at low precision."""
        left = SparseArray([0.5, 0.123456])
        right = SparseArray([0.5, 0.123457])
        assert_extension_array_equal(
            left, right, check_less_precise=check_less_precise)

    @pytest.mark.parametrize('check_less_precise', [False, 5, 6, 7, 8, 9])
    def test_check_less_precise_fails(self, check_less_precise):
        """A difference in the 6th decimal is reported at high precision."""
        left = SparseArray([0.5, 0.123456])
        right = SparseArray([0.5, 0.123457])

        msg = textwrap.dedent("""\
            ExtensionArray are different

            ExtensionArray values are different \\(50\\.0 %\\)
            \\[left\\]:\\s+\\[0\\.5, 0\\.123456\\]
            \\[right\\]:\\s+\\[0\\.5, 0\\.123457\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(
                left, right, check_less_precise=check_less_precise)

    def test_check_dtype(self):
        """int64 vs int32 sparse arrays only compare equal without dtype check."""
        left = SparseArray(np.arange(5, dtype='int64'))
        right = SparseArray(np.arange(5, dtype='int32'))

        # passes with check_dtype=False
        assert_extension_array_equal(left, right, check_dtype=False)

        # raises with check_dtype=True
        msg = textwrap.dedent("""\
            ExtensionArray are different

            Attribute "dtype" are different
            \\[left\\]:\\s+Sparse\\[int64, 0\\]
            \\[right\\]:\\s+Sparse\\[int32, 0\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right, check_dtype=True)

    def test_missing_values(self):
        """Arrays whose NA positions differ fail on the NA-mask comparison."""
        left = SparseArray([np.nan, 1, 2, np.nan])
        right = SparseArray([np.nan, 1, 2, 3])

        msg = textwrap.dedent("""\
            ExtensionArray NA mask are different

            ExtensionArray NA mask values are different \\(25\\.0 %\\)
            \\[left\\]:\\s+\\[True, False, False, True\\]
            \\[right\\]:\\s+\\[True, False, False, False\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right)

    def test_non_extension_array(self):
        """A plain numpy array on either side is rejected with a clear message."""
        numpy_array = np.arange(5)
        extension_array = SparseArray(np.arange(5))

        msg = 'left is not an ExtensionArray'
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(numpy_array, extension_array)

        msg = 'right is not an ExtensionArray'
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(extension_array, numpy_array)
def assert_extension_array_equal(left, right, check_dtype=True,
                                 check_less_precise=False,
                                 check_exact=False):
    """Check that left and right ExtensionArrays are equal.

    Parameters
    ----------
    left, right : ExtensionArray
        The two arrays to compare
    check_dtype : bool, default True
        Whether to check if the ExtensionArray dtypes are identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default False
        Whether to compare numbers exactly.

    Raises
    ------
    AssertionError
        If either argument is not an ExtensionArray, or if the dtypes
        (when ``check_dtype`` is True), the missing-value masks or the
        non-missing values differ.

    Notes
    -----
    Missing values are checked separately from valid values.
    A mask of missing values is computed for each and checked to match.
    The remaining all-valid values are cast to object dtype and checked.
    """
    # Raise explicitly instead of using ``assert`` statements: asserts are
    # removed under ``python -O``, which would let non-ExtensionArray inputs
    # slip past this validation.
    if not isinstance(left, ExtensionArray):
        raise AssertionError('left is not an ExtensionArray')
    if not isinstance(right, ExtensionArray):
        raise AssertionError('right is not an ExtensionArray')
    if check_dtype:
        assert_attr_equal('dtype', left, right, obj='ExtensionArray')

    # The NA masks must agree before the valid values can be compared
    # position by position.
    left_na = np.asarray(left.isna())
    right_na = np.asarray(right.isna())
    assert_numpy_array_equal(left_na, right_na, obj='ExtensionArray NA mask')

    # Compare only the non-missing values, cast to object dtype.
    left_valid = np.asarray(left[~left_na].astype(object))
    right_valid = np.asarray(right[~right_na].astype(object))
    if check_exact:
        assert_numpy_array_equal(left_valid, right_valid, obj='ExtensionArray')
    else:
        _testing.assert_almost_equal(left_valid, right_valid,
                                     check_dtype=check_dtype,
                                     check_less_precise=check_less_precise,
                                     obj='ExtensionArray')


# This could be refactored to use the NDFrame.equals method