From 15ef4706fc0ab3764668299ffee42c710be4452d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 29 Mar 2024 15:27:29 +0100 Subject: [PATCH] Add strict_na keyword to the assert_.._equal methods for object dtype to help with deprecation --- pandas/_libs/testing.pyi | 1 + pandas/_libs/testing.pyx | 13 +++++++++---- pandas/_testing/asserters.py | 8 +++++++- pandas/tests/util/test_assert_almost_equal.py | 10 +++++++++- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/pandas/_libs/testing.pyi b/pandas/_libs/testing.pyi index ab87e58eba9b9..312ff17dca26e 100644 --- a/pandas/_libs/testing.pyi +++ b/pandas/_libs/testing.pyi @@ -7,6 +7,7 @@ def assert_almost_equal( rtol: float = ..., atol: float = ..., check_dtype: bool = ..., + strict_na: bool = ..., obj=..., lobj=..., robj=..., diff --git a/pandas/_libs/testing.pyx b/pandas/_libs/testing.pyx index cfd31fa610e69..929fbb57f5d8b 100644 --- a/pandas/_libs/testing.pyx +++ b/pandas/_libs/testing.pyx @@ -52,7 +52,7 @@ cpdef assert_dict_equal(a, b, bint compare_keys=True): cpdef assert_almost_equal(a, b, rtol=1.e-5, atol=1.e-8, - bint check_dtype=True, + bint check_dtype=True, bint strict_na=True, obj=None, lobj=None, robj=None, index_values=None): """ Check that left and right objects are almost equal. @@ -67,6 +67,7 @@ cpdef assert_almost_equal(a, b, Absolute tolerance. check_dtype: bool, default True check dtype if both a and b are np.ndarray. + strict_na : bool, default True obj : str, default None Specify object name being compared, internally used to show appropriate assertion message. @@ -155,7 +156,9 @@ cpdef assert_almost_equal(a, b, for i in range(len(a)): try: - assert_almost_equal(a[i], b[i], rtol=rtol, atol=atol) + assert_almost_equal( + a[i], b[i], rtol=rtol, atol=atol, strict_na=strict_na + ) except AssertionError: is_unequal = True diff += 1 @@ -185,8 +188,10 @@ cpdef assert_almost_equal(a, b, if is_matching_na(a, b, nan_matches_none=False): return True elif checknull(b): - # GH#18463 - raise AssertionError(f"Mismatched null-like values {a} != {b}") + if strict_na: + raise AssertionError(f"Mismatched null-like values {a} != {b}") + else: + return True raise AssertionError(f"{a} != {b}") elif checknull(b): raise AssertionError(f"{a} != {b}") diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index 3aacd3099c334..eb1360970f750 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -69,6 +69,7 @@ def assert_almost_equal( check_dtype: bool | Literal["equiv"] = "equiv", rtol: float = 1.0e-5, atol: float = 1.0e-8, + strict_na: bool = True, **kwargs, ) -> None: """ @@ -89,6 +90,7 @@ def assert_almost_equal( Relative tolerance. atol : float, default 1e-8 Absolute tolerance. + strict_na : bool, default True """ if isinstance(left, Index): assert_index_equal( @@ -141,7 +143,7 @@ def assert_almost_equal( # if we have "equiv", this becomes True _testing.assert_almost_equal( - left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs + left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, strict_na=strict_na, **kwargs ) @@ -839,6 +841,7 @@ def assert_series_equal( check_flags: bool = True, rtol: float | lib.NoDefault = lib.no_default, atol: float | lib.NoDefault = lib.no_default, + strict_na: bool = True, obj: str = "Series", *, check_index: bool = True, @@ -1070,6 +1073,7 @@ def assert_series_equal( rtol=rtol, atol=atol, check_dtype=bool(check_dtype), + strict_na=strict_na, obj=str(obj), index_values=left.index, ) @@ -1108,6 +1112,7 @@ def assert_frame_equal( check_flags: bool = True, rtol: float | lib.NoDefault = lib.no_default, atol: float | lib.NoDefault = lib.no_default, + strict_na: bool = True, obj: str = "DataFrame", ) -> None: """ @@ -1291,6 +1296,7 @@ def assert_frame_equal( atol=atol, check_index=False, check_flags=False, + strict_na=strict_na, ) diff --git a/pandas/tests/util/test_assert_almost_equal.py b/pandas/tests/util/test_assert_almost_equal.py index bcc2e4e03f367..9e2304d5bee89 100644 --- a/pandas/tests/util/test_assert_almost_equal.py +++ b/pandas/tests/util/test_assert_almost_equal.py @@ -333,6 +333,8 @@ def test_mismatched_na_assert_almost_equal(left, right): else: with pytest.raises(AssertionError, match=msg): _assert_almost_equal_both(left, right, check_dtype=False) + + _assert_almost_equal_both(left, right, check_dtype=False, strict_na=False) # TODO: to get the same deprecation in assert_numpy_array_equal we need # to change/deprecate the default for strict_nan to become True @@ -343,11 +345,17 @@ def test_mismatched_na_assert_almost_equal(left, right): tm.assert_series_equal( Series(left_arr, dtype=object), Series(right_arr, dtype=object) ) + tm.assert_series_equal( + Series(left_arr, dtype=object), Series(right_arr, dtype=object), strict_na=False + ) + with pytest.raises(AssertionError, match="DataFrame.iloc.* are different"): tm.assert_frame_equal( DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object) ) - + tm.assert_frame_equal( + DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object), strict_na=False + ) def test_assert_not_almost_equal_inf(): _assert_not_almost_equal_both(np.inf, 0)