Skip to content

Add strict_na keyword to the assert_.._equal methods for object dtype to help with deprecation #58072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/testing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def assert_almost_equal(
rtol: float = ...,
atol: float = ...,
check_dtype: bool = ...,
strict_na: bool = ...,
obj=...,
lobj=...,
robj=...,
Expand Down
13 changes: 9 additions & 4 deletions pandas/_libs/testing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cpdef assert_dict_equal(a, b, bint compare_keys=True):

cpdef assert_almost_equal(a, b,
rtol=1.e-5, atol=1.e-8,
bint check_dtype=True,
bint check_dtype=True, bint strict_na=True,
obj=None, lobj=None, robj=None, index_values=None):
"""
Check that left and right objects are almost equal.
Expand All @@ -67,6 +67,7 @@ cpdef assert_almost_equal(a, b,
Absolute tolerance.
check_dtype: bool, default True
check dtype if both a and b are np.ndarray.
strict_na : bool, default True
obj : str, default None
Specify object name being compared, internally used to show
appropriate assertion message.
Expand Down Expand Up @@ -155,7 +156,9 @@ cpdef assert_almost_equal(a, b,

for i in range(len(a)):
try:
assert_almost_equal(a[i], b[i], rtol=rtol, atol=atol)
assert_almost_equal(
a[i], b[i], rtol=rtol, atol=atol, strict_na=strict_na
)
except AssertionError:
is_unequal = True
diff += 1
Expand Down Expand Up @@ -185,8 +188,10 @@ cpdef assert_almost_equal(a, b,
if is_matching_na(a, b, nan_matches_none=False):
return True
elif checknull(b):
# GH#18463
raise AssertionError(f"Mismatched null-like values {a} != {b}")
if strict_na:
raise AssertionError(f"Mismatched null-like values {a} != {b}")
else:
return True
raise AssertionError(f"{a} != {b}")
elif checknull(b):
raise AssertionError(f"{a} != {b}")
Expand Down
8 changes: 7 additions & 1 deletion pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def assert_almost_equal(
check_dtype: bool | Literal["equiv"] = "equiv",
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
strict_na: bool = True,
**kwargs,
) -> None:
"""
Expand All @@ -89,6 +90,7 @@ def assert_almost_equal(
Relative tolerance.
atol : float, default 1e-8
Absolute tolerance.
strict_na : bool, default True
"""
if isinstance(left, Index):
assert_index_equal(
Expand Down Expand Up @@ -141,7 +143,7 @@ def assert_almost_equal(

# if we have "equiv", this becomes True
_testing.assert_almost_equal(
left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs
left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, strict_na=strict_na, **kwargs
)


Expand Down Expand Up @@ -839,6 +841,7 @@ def assert_series_equal(
check_flags: bool = True,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
strict_na: bool = True,
obj: str = "Series",
*,
check_index: bool = True,
Expand Down Expand Up @@ -1070,6 +1073,7 @@ def assert_series_equal(
rtol=rtol,
atol=atol,
check_dtype=bool(check_dtype),
strict_na=strict_na,
obj=str(obj),
index_values=left.index,
)
Expand Down Expand Up @@ -1108,6 +1112,7 @@ def assert_frame_equal(
check_flags: bool = True,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
strict_na: bool = True,
obj: str = "DataFrame",
) -> None:
"""
Expand Down Expand Up @@ -1291,6 +1296,7 @@ def assert_frame_equal(
atol=atol,
check_index=False,
check_flags=False,
strict_na=strict_na,
)


Expand Down
10 changes: 9 additions & 1 deletion pandas/tests/util/test_assert_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ def test_mismatched_na_assert_almost_equal(left, right):
else:
with pytest.raises(AssertionError, match=msg):
_assert_almost_equal_both(left, right, check_dtype=False)

_assert_almost_equal_both(left, right, check_dtype=False, strict_na=False)

# TODO: to get the same deprecation in assert_numpy_array_equal we need
# to change/deprecate the default for strict_nan to become True
Expand All @@ -343,11 +345,17 @@ def test_mismatched_na_assert_almost_equal(left, right):
tm.assert_series_equal(
Series(left_arr, dtype=object), Series(right_arr, dtype=object)
)
tm.assert_series_equal(
Series(left_arr, dtype=object), Series(right_arr, dtype=object), strict_na=False
)

with pytest.raises(AssertionError, match="DataFrame.iloc.* are different"):
tm.assert_frame_equal(
DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object)
)

tm.assert_frame_equal(
DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object), strict_na=False
)

def test_assert_not_almost_equal_inf():
_assert_not_almost_equal_both(np.inf, 0)
Expand Down
Loading