Skip to content

Commit 2cef406

Browse files
jschendel authored and
Pingviinituutti committed
TST: Make assert_extension_array_equal behavior consistent (pandas-dev#23808)
1 parent d64c5c2 commit 2cef406

File tree

2 files changed

+114
-9
lines changed

2 files changed

+114
-9
lines changed

pandas/tests/util/test_testing.py

+90-2
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,12 @@
1111

1212
import pandas as pd
1313
from pandas import DataFrame, Series, compat
14+
from pandas.core.arrays.sparse import SparseArray
1415
import pandas.util.testing as tm
1516
from pandas.util.testing import (
16-
RNGContext, assert_almost_equal, assert_frame_equal, assert_index_equal,
17-
assert_numpy_array_equal, assert_series_equal)
17+
RNGContext, assert_almost_equal, assert_extension_array_equal,
18+
assert_frame_equal, assert_index_equal, assert_numpy_array_equal,
19+
assert_series_equal)
1820

1921

2022
class TestAssertAlmostEqual(object):
@@ -850,6 +852,92 @@ def test_interval_array_equal_message(self):
850852
tm.assert_interval_array_equal(a, b)
851853

852854

855+
class TestAssertExtensionArrayEqual(object):
    """Tests for ``assert_extension_array_equal`` and its keyword options.

    Every expected failure message is a regex handed to
    ``pytest.raises(..., match=...)``.  The patterns keep two spaces after
    ``[left]:`` because the pandas assertion detail pads that label so it
    lines up with ``[right]:``; a single space would not match.
    """

    def test_check_exact(self):
        """Nearly equal floats pass by default, fail with check_exact=True."""
        # GH 23709
        left = SparseArray([-0.17387645482451206, 0.3414148016424936])
        right = SparseArray([-0.17387645482451206, 0.3414148016424937])

        # passes with check_exact=False (should be default)
        assert_extension_array_equal(left, right)
        assert_extension_array_equal(left, right, check_exact=False)

        # raises with check_exact=True
        msg = textwrap.dedent("""\
            ExtensionArray are different

            ExtensionArray values are different \\(50\\.0 %\\)
            \\[left\\]:  \\[-0\\.17387645482.*, 0\\.341414801642.*\\]
            \\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right, check_exact=True)

    @pytest.mark.parametrize('check_less_precise', [True, 0, 1, 2, 3, 4])
    def test_check_less_precise_passes(self, check_less_precise):
        """Values differing in the 6th decimal pass at these precisions."""
        left = SparseArray([0.5, 0.123456])
        right = SparseArray([0.5, 0.123457])
        assert_extension_array_equal(
            left, right, check_less_precise=check_less_precise)

    @pytest.mark.parametrize('check_less_precise', [False, 5, 6, 7, 8, 9])
    def test_check_less_precise_fails(self, check_less_precise):
        """Values differing in the 6th decimal fail at these precisions."""
        left = SparseArray([0.5, 0.123456])
        right = SparseArray([0.5, 0.123457])

        msg = textwrap.dedent("""\
            ExtensionArray are different

            ExtensionArray values are different \\(50\\.0 %\\)
            \\[left\\]:  \\[0\\.5, 0\\.123456\\]
            \\[right\\]: \\[0\\.5, 0\\.123457\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(
                left, right, check_less_precise=check_less_precise)

    def test_check_dtype(self):
        """int64 vs int32 sparse data only compares equal without dtype check."""
        left = SparseArray(np.arange(5, dtype='int64'))
        right = SparseArray(np.arange(5, dtype='int32'))

        # passes with check_dtype=False
        assert_extension_array_equal(left, right, check_dtype=False)

        # raises with check_dtype=True
        msg = textwrap.dedent("""\
            ExtensionArray are different

            Attribute "dtype" are different
            \\[left\\]:  Sparse\\[int64, 0\\]
            \\[right\\]: Sparse\\[int32, 0\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right, check_dtype=True)

    def test_missing_values(self):
        """A NaN on one side only is reported as an NA-mask difference."""
        left = SparseArray([np.nan, 1, 2, np.nan])
        right = SparseArray([np.nan, 1, 2, 3])

        msg = textwrap.dedent("""\
            ExtensionArray NA mask are different

            ExtensionArray NA mask values are different \\(25\\.0 %\\)
            \\[left\\]:  \\[True, False, False, True\\]
            \\[right\\]: \\[True, False, False, False\\]""")
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(left, right)

    def test_non_extension_array(self):
        """A plain ndarray on either side is rejected with a clear message."""
        numpy_array = np.arange(5)
        extension_array = SparseArray(np.arange(5))

        msg = 'left is not an ExtensionArray'
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(numpy_array, extension_array)

        msg = 'right is not an ExtensionArray'
        with pytest.raises(AssertionError, match=msg):
            assert_extension_array_equal(extension_array, numpy_array)
939+
940+
853941
class TestRNGContext(object):
854942

855943
def test_RNGContext(self):

pandas/util/testing.py

+24-7
Original file line number | Diff line number | Diff line change
@@ -1197,31 +1197,48 @@ def _raise(left, right, err_msg):
11971197
return True
11981198

11991199

1200-
def assert_extension_array_equal(left, right, check_dtype=True,
                                 check_less_precise=False,
                                 check_exact=False):
    """Check that left and right ExtensionArrays are equal.

    Parameters
    ----------
    left, right : ExtensionArray
        The two arrays to compare
    check_dtype : bool, default True
        Whether to check if the ExtensionArray dtypes are identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default False
        Whether to compare numbers exactly.

    Raises
    ------
    AssertionError
        If either argument is not an ExtensionArray, if the dtypes differ
        while ``check_dtype`` is True, if the missing-value masks differ,
        or if the remaining valid values differ.

    Notes
    -----
    Missing values are checked separately from valid values.
    A mask of missing values is computed for each and checked to match.
    The remaining all-valid values are cast to object dtype and checked.
    """
    # Raise explicitly instead of using ``assert`` statements: those are
    # removed when Python runs with -O, which would let a non-ExtensionArray
    # slip through to the ``isna()`` call below and fail less clearly.
    if not isinstance(left, ExtensionArray):
        raise AssertionError('left is not an ExtensionArray')
    if not isinstance(right, ExtensionArray):
        raise AssertionError('right is not an ExtensionArray')

    if check_dtype:
        assert_attr_equal('dtype', left, right, obj='ExtensionArray')

    # The positions of missing values must agree before values are compared.
    left_na = np.asarray(left.isna())
    right_na = np.asarray(right.isna())
    assert_numpy_array_equal(left_na, right_na, obj='ExtensionArray NA mask')

    # Compare only the non-missing entries, as object-dtype ndarrays.
    left_valid = np.asarray(left[~left_na].astype(object))
    right_valid = np.asarray(right[~right_na].astype(object))
    if check_exact:
        assert_numpy_array_equal(left_valid, right_valid, obj='ExtensionArray')
    else:
        _testing.assert_almost_equal(left_valid, right_valid,
                                     check_dtype=check_dtype,
                                     check_less_precise=check_less_precise,
                                     obj='ExtensionArray')
12251242

12261243

12271244
# This could be refactored to use the NDFrame.equals method

0 commit comments

Comments
 (0)