Skip to content

TST: Don't ignore tolerance for integer series #56724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 72 additions & 26 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np

from pandas._libs import lib
from pandas._libs.missing import is_matching_na
from pandas._libs.sparse import SparseIndex
import pandas._libs.testing as _testing
Expand Down Expand Up @@ -698,9 +699,9 @@ def assert_extension_array_equal(
right,
check_dtype: bool | Literal["equiv"] = True,
index_values=None,
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
check_exact: bool | lib.NoDefault = lib.no_default,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "ExtensionArray",
) -> None:
"""
Expand All @@ -715,7 +716,12 @@ def assert_extension_array_equal(
index_values : Index | numpy.ndarray, default None
Optional index (shared by both left and right), used in output.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.

.. versionchanged:: 2.2.0

Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
rtol : float, default 1e-5
Relative tolerance. Only used when check_exact is False.
atol : float, default 1e-8
Expand All @@ -739,6 +745,23 @@ def assert_extension_array_equal(
>>> b, c = a.array, a.array
>>> tm.assert_extension_array_equal(b, c)
"""
if (
check_exact is lib.no_default
and rtol is lib.no_default
and atol is lib.no_default
):
check_exact = (
is_numeric_dtype(left.dtype)
and not is_float_dtype(left.dtype)
or is_numeric_dtype(right.dtype)
and not is_float_dtype(right.dtype)
)
Comment on lines +753 to +758
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick - I can never remember how a logic statement like this will be parsed, fancy making it more explicit with

        check_exact = (
            (is_numeric_dtype(left.dtype)
            and not is_float_dtype(left.dtype))
            or (is_numeric_dtype(right.dtype)
            and not is_float_dtype(right.dtype))
        )

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would make the indentation worse? Not sure how black treats those

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

up to you, but it's very easy to get these wrong, e.g. https://github.com/pandas-dev/pandas/pull/34334/files

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indents look really off with parentheses, will merge for now

elif check_exact is lib.no_default:
check_exact = False

rtol = rtol if rtol is not lib.no_default else 1.0e-5
atol = atol if atol is not lib.no_default else 1.0e-8

assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
if check_dtype:
Expand Down Expand Up @@ -784,10 +807,7 @@ def assert_extension_array_equal(

left_valid = left[~left_na].to_numpy(dtype=object)
right_valid = right[~right_na].to_numpy(dtype=object)
if check_exact or (
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
):
if check_exact:
assert_numpy_array_equal(
left_valid, right_valid, obj=obj, index_values=index_values
)
Expand All @@ -811,14 +831,14 @@ def assert_series_equal(
check_index_type: bool | Literal["equiv"] = "equiv",
check_series_type: bool = True,
check_names: bool = True,
check_exact: bool = False,
check_exact: bool | lib.NoDefault = lib.no_default,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_category_order: bool = True,
check_freq: bool = True,
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "Series",
*,
check_index: bool = True,
Expand All @@ -841,7 +861,12 @@ def assert_series_equal(
check_names : bool, default True
Whether to check the Series and Index names attribute.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.

.. versionchanged:: 2.2.0

Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
check_datetimelike_compat : bool, default False
Compare datetime-like which is comparable ignoring dtype.
check_categorical : bool, default True
Expand Down Expand Up @@ -877,6 +902,22 @@ def assert_series_equal(
>>> tm.assert_series_equal(a, b)
"""
__tracebackhide__ = True
if (
check_exact is lib.no_default
and rtol is lib.no_default
and atol is lib.no_default
):
check_exact = (
is_numeric_dtype(left.dtype)
and not is_float_dtype(left.dtype)
or is_numeric_dtype(right.dtype)
and not is_float_dtype(right.dtype)
)
elif check_exact is lib.no_default:
check_exact = False

rtol = rtol if rtol is not lib.no_default else 1.0e-5
atol = atol if atol is not lib.no_default else 1.0e-8

if not check_index and check_like:
raise ValueError("check_like must be False if check_index is False")
Expand Down Expand Up @@ -931,10 +972,7 @@ def assert_series_equal(
pass
else:
assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
if check_exact or (
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
):
if check_exact:
left_values = left._values
right_values = right._values
# Only check exact if dtype is numeric
Expand Down Expand Up @@ -1061,14 +1099,14 @@ def assert_frame_equal(
check_frame_type: bool = True,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_exact: bool | lib.NoDefault = lib.no_default,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
check_freq: bool = True,
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "DataFrame",
) -> None:
"""
Expand Down Expand Up @@ -1103,7 +1141,12 @@ def assert_frame_equal(
Specify how to compare internal data. If False, compare by columns.
If True, compare by blocks.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.

.. versionchanged:: 2.2.0

Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
check_datetimelike_compat : bool, default False
Compare datetime-like which is comparable ignoring dtype.
check_categorical : bool, default True
Expand Down Expand Up @@ -1158,6 +1201,9 @@ def assert_frame_equal(
>>> assert_frame_equal(df1, df2, check_dtype=False)
"""
__tracebackhide__ = True
_rtol = rtol if rtol is not lib.no_default else 1.0e-5
_atol = atol if atol is not lib.no_default else 1.0e-8
_check_exact = check_exact if check_exact is not lib.no_default else False

# instance validation
_check_isinstance(left, right, DataFrame)
Expand All @@ -1181,11 +1227,11 @@ def assert_frame_equal(
right.index,
exact=check_index_type,
check_names=check_names,
check_exact=check_exact,
check_exact=_check_exact,
check_categorical=check_categorical,
check_order=not check_like,
rtol=rtol,
atol=atol,
rtol=_rtol,
atol=_atol,
obj=f"{obj}.index",
)

Expand All @@ -1195,11 +1241,11 @@ def assert_frame_equal(
right.columns,
exact=check_column_type,
check_names=check_names,
check_exact=check_exact,
check_exact=_check_exact,
check_categorical=check_categorical,
check_order=not check_like,
rtol=rtol,
atol=atol,
rtol=_rtol,
atol=_atol,
obj=f"{obj}.columns",
)

Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/util/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,15 @@ def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
left = Series([1, 2, val], dtype=dtype)
right = Series(pd.array([1, 2, val]))
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)


def test_assert_series_equal_int_tol():
    """Check that an explicit ``rtol`` is honoured for integer data (GH#56646).

    The two inputs differ in most positions, so each asserter below can only
    pass if the relative tolerance supplied by the caller is actually applied.
    The same pair of inputs is checked three ways: as Series, as one-column
    DataFrames, and as ``Int64`` extension arrays.
    """
    rel_tol = 1.5
    observed = Series(
        [81, 18, 121, 38, 74, 72, 81, 81, 146, 81, 81, 170, 74, 74]
    )
    reference = Series(
        [72, 9, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72]
    )

    # Plain integer Series.
    tm.assert_series_equal(observed, reference, rtol=rel_tol)

    # Same data wrapped in one-column DataFrames.
    observed_frame = observed.to_frame()
    reference_frame = reference.to_frame()
    tm.assert_frame_equal(observed_frame, reference_frame, rtol=rel_tol)

    # Same data converted to the "Int64" extension dtype.
    observed_ea = observed.astype("Int64").values
    reference_ea = reference.astype("Int64").values
    tm.assert_extension_array_equal(observed_ea, reference_ea, rtol=rel_tol)