Skip to content

Commit 07b363e

Browse files
authored
ENH: Include column for ea comparison in asserters (#50323)
* ENH: Include column for ea comparison in asserters * Add gh ref * Fix test * Add gh ref * Split tests
1 parent 9015013 commit 07b363e

File tree

5 files changed

+56
-14
lines changed

5 files changed

+56
-14
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ Other enhancements
9393
- :func:`timedelta_range` now supports a ``unit`` keyword ("s", "ms", "us", or "ns") to specify the desired resolution of the output index (:issue:`49824`)
9494
- :meth:`DataFrame.to_json` now supports a ``mode`` keyword with supported inputs 'w' and 'a'. Defaulting to 'w', 'a' can be used when lines=True and orient='records' to append record oriented json lines to an existing json file. (:issue:`35849`)
9595
- Added ``name`` parameter to :meth:`IntervalIndex.from_breaks`, :meth:`IntervalIndex.from_arrays` and :meth:`IntervalIndex.from_tuples` (:issue:`48911`)
96+
- Improve exception message when using :func:`assert_frame_equal` on a :class:`DataFrame` to include the column that is compared (:issue:`50323`)
9697
- Improved error message for :func:`merge_asof` when join-columns were duplicated (:issue:`50102`)
9798
- Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`)
9899
- Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`)

pandas/_testing/asserters.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def assert_extension_array_equal(
680680
check_exact: bool = False,
681681
rtol: float = 1.0e-5,
682682
atol: float = 1.0e-8,
683+
obj: str = "ExtensionArray",
683684
) -> None:
684685
"""
685686
Check that left and right ExtensionArrays are equal.
@@ -702,6 +703,11 @@ def assert_extension_array_equal(
702703
Absolute tolerance. Only used when check_exact is False.
703704
704705
.. versionadded:: 1.1.0
706+
obj : str, default 'ExtensionArray'
707+
Specify object name being compared, internally used to show appropriate
708+
assertion message.
709+
710+
.. versionadded:: 2.0.0
705711
706712
Notes
707713
-----
@@ -719,7 +725,7 @@ def assert_extension_array_equal(
719725
assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
720726
assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
721727
if check_dtype:
722-
assert_attr_equal("dtype", left, right, obj="ExtensionArray")
728+
assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
723729

724730
if (
725731
isinstance(left, DatetimeLikeArrayMixin)
@@ -729,21 +735,24 @@ def assert_extension_array_equal(
729735
# Avoid slow object-dtype comparisons
730736
# np.asarray for case where we have a np.MaskedArray
731737
assert_numpy_array_equal(
732-
np.asarray(left.asi8), np.asarray(right.asi8), index_values=index_values
738+
np.asarray(left.asi8),
739+
np.asarray(right.asi8),
740+
index_values=index_values,
741+
obj=obj,
733742
)
734743
return
735744

736745
left_na = np.asarray(left.isna())
737746
right_na = np.asarray(right.isna())
738747
assert_numpy_array_equal(
739-
left_na, right_na, obj="ExtensionArray NA mask", index_values=index_values
748+
left_na, right_na, obj=f"{obj} NA mask", index_values=index_values
740749
)
741750

742751
left_valid = left[~left_na].to_numpy(dtype=object)
743752
right_valid = right[~right_na].to_numpy(dtype=object)
744753
if check_exact:
745754
assert_numpy_array_equal(
746-
left_valid, right_valid, obj="ExtensionArray", index_values=index_values
755+
left_valid, right_valid, obj=obj, index_values=index_values
747756
)
748757
else:
749758
_testing.assert_almost_equal(
@@ -752,7 +761,7 @@ def assert_extension_array_equal(
752761
check_dtype=bool(check_dtype),
753762
rtol=rtol,
754763
atol=atol,
755-
obj="ExtensionArray",
764+
obj=obj,
756765
index_values=index_values,
757766
)
758767

@@ -909,6 +918,7 @@ def assert_series_equal(
909918
right_values,
910919
check_dtype=check_dtype,
911920
index_values=np.asarray(left.index),
921+
obj=str(obj),
912922
)
913923
else:
914924
assert_numpy_array_equal(
@@ -955,6 +965,7 @@ def assert_series_equal(
955965
atol=atol,
956966
check_dtype=check_dtype,
957967
index_values=np.asarray(left.index),
968+
obj=str(obj),
958969
)
959970
elif is_extension_array_dtype_and_needs_i8_conversion(
960971
left.dtype, right.dtype
@@ -964,6 +975,7 @@ def assert_series_equal(
964975
right._values,
965976
check_dtype=check_dtype,
966977
index_values=np.asarray(left.index),
978+
obj=str(obj),
967979
)
968980
elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype):
969981
# DatetimeArray or TimedeltaArray
@@ -972,6 +984,7 @@ def assert_series_equal(
972984
right._values,
973985
check_dtype=check_dtype,
974986
index_values=np.asarray(left.index),
987+
obj=str(obj),
975988
)
976989
else:
977990
_testing.assert_almost_equal(

pandas/tests/extension/json/test_json.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_custom_asserts(self):
140140
self.assert_frame_equal(a.to_frame(), a.to_frame())
141141

142142
b = pd.Series(data.take([0, 0, 1]))
143-
msg = r"ExtensionArray are different"
143+
msg = r"Series are different"
144144
with pytest.raises(AssertionError, match=msg):
145145
self.assert_series_equal(a, b)
146146

pandas/tests/util/test_assert_frame_equal.py

+33
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,36 @@ def test_assert_frame_equal_check_like_categorical_midx():
366366
),
367367
)
368368
tm.assert_frame_equal(left, right, check_like=True)
369+
370+
371+
def test_assert_frame_equal_ea_column_definition_in_exception_mask():
372+
# GH#50323
373+
df1 = DataFrame({"a": pd.Series([pd.NA, 1], dtype="Int64")})
374+
df2 = DataFrame({"a": pd.Series([1, 1], dtype="Int64")})
375+
376+
msg = r'DataFrame.iloc\[:, 0\] \(column name="a"\) NA mask values are different'
377+
with pytest.raises(AssertionError, match=msg):
378+
tm.assert_frame_equal(df1, df2)
379+
380+
381+
def test_assert_frame_equal_ea_column_definition_in_exception():
382+
# GH#50323
383+
df1 = DataFrame({"a": pd.Series([pd.NA, 1], dtype="Int64")})
384+
df2 = DataFrame({"a": pd.Series([pd.NA, 2], dtype="Int64")})
385+
386+
msg = r'DataFrame.iloc\[:, 0\] \(column name="a"\) values are different'
387+
with pytest.raises(AssertionError, match=msg):
388+
tm.assert_frame_equal(df1, df2)
389+
390+
with pytest.raises(AssertionError, match=msg):
391+
tm.assert_frame_equal(df1, df2, check_exact=True)
392+
393+
394+
def test_assert_frame_equal_ts_column():
395+
# GH#50323
396+
df1 = DataFrame({"a": [pd.Timestamp("2019-12-31"), pd.Timestamp("2020-12-31")]})
397+
df2 = DataFrame({"a": [pd.Timestamp("2020-12-31"), pd.Timestamp("2020-12-31")]})
398+
399+
msg = r'DataFrame.iloc\[:, 0\] \(column name="a"\) values are different'
400+
with pytest.raises(AssertionError, match=msg):
401+
tm.assert_frame_equal(df1, df2)

pandas/tests/util/test_assert_series_equal.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas.core.dtypes.common import is_extension_array_dtype
5-
64
import pandas as pd
75
from pandas import (
86
Categorical,
@@ -116,10 +114,7 @@ def test_less_precise(data1, data2, dtype, decimals):
116114
s2 = Series([data2], dtype=dtype)
117115

118116
if decimals in (5, 10) or (decimals >= 3 and abs(data1 - data2) >= 0.0005):
119-
if is_extension_array_dtype(dtype):
120-
msg = "ExtensionArray are different"
121-
else:
122-
msg = "Series values are different"
117+
msg = "Series values are different"
123118
with pytest.raises(AssertionError, match=msg):
124119
tm.assert_series_equal(s1, s2, rtol=rtol)
125120
else:
@@ -237,9 +232,9 @@ def test_series_equal_categorical_values_mismatch(rtol):
237232

238233

239234
def test_series_equal_datetime_values_mismatch(rtol):
240-
msg = """numpy array are different
235+
msg = """Series are different
241236
242-
numpy array values are different \\(100.0 %\\)
237+
Series values are different \\(100.0 %\\)
243238
\\[index\\]: \\[0, 1, 2\\]
244239
\\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\]
245240
\\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]"""

0 commit comments

Comments
 (0)