Skip to content

Commit be35d67

Browse files
phoflpmhatre1
authored andcommitted
TST: Don't ignore tolerance for integer series (pandas-dev#56724)
1 parent dbcdb37 commit be35d67

File tree

2 files changed

+84
-26
lines changed

2 files changed

+84
-26
lines changed

pandas/_testing/asserters.py

+72-26
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
from pandas._libs import lib
1314
from pandas._libs.missing import is_matching_na
1415
from pandas._libs.sparse import SparseIndex
1516
import pandas._libs.testing as _testing
@@ -698,9 +699,9 @@ def assert_extension_array_equal(
698699
right,
699700
check_dtype: bool | Literal["equiv"] = True,
700701
index_values=None,
701-
check_exact: bool = False,
702-
rtol: float = 1.0e-5,
703-
atol: float = 1.0e-8,
702+
check_exact: bool | lib.NoDefault = lib.no_default,
703+
rtol: float | lib.NoDefault = lib.no_default,
704+
atol: float | lib.NoDefault = lib.no_default,
704705
obj: str = "ExtensionArray",
705706
) -> None:
706707
"""
@@ -715,7 +716,12 @@ def assert_extension_array_equal(
715716
index_values : Index | numpy.ndarray, default None
716717
Optional index (shared by both left and right), used in output.
717718
check_exact : bool, default False
718-
Whether to compare number exactly. Only takes effect for float dtypes.
719+
Whether to compare number exactly.
720+
721+
.. versionchanged:: 2.2.0
722+
723+
Defaults to True for integer dtypes if none of
724+
``check_exact``, ``rtol`` and ``atol`` are specified.
719725
rtol : float, default 1e-5
720726
Relative tolerance. Only used when check_exact is False.
721727
atol : float, default 1e-8
@@ -739,6 +745,23 @@ def assert_extension_array_equal(
739745
>>> b, c = a.array, a.array
740746
>>> tm.assert_extension_array_equal(b, c)
741747
"""
748+
if (
749+
check_exact is lib.no_default
750+
and rtol is lib.no_default
751+
and atol is lib.no_default
752+
):
753+
check_exact = (
754+
is_numeric_dtype(left.dtype)
755+
and not is_float_dtype(left.dtype)
756+
or is_numeric_dtype(right.dtype)
757+
and not is_float_dtype(right.dtype)
758+
)
759+
elif check_exact is lib.no_default:
760+
check_exact = False
761+
762+
rtol = rtol if rtol is not lib.no_default else 1.0e-5
763+
atol = atol if atol is not lib.no_default else 1.0e-8
764+
742765
assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
743766
assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
744767
if check_dtype:
@@ -784,10 +807,7 @@ def assert_extension_array_equal(
784807

785808
left_valid = left[~left_na].to_numpy(dtype=object)
786809
right_valid = right[~right_na].to_numpy(dtype=object)
787-
if check_exact or (
788-
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
789-
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
790-
):
810+
if check_exact:
791811
assert_numpy_array_equal(
792812
left_valid, right_valid, obj=obj, index_values=index_values
793813
)
@@ -811,14 +831,14 @@ def assert_series_equal(
811831
check_index_type: bool | Literal["equiv"] = "equiv",
812832
check_series_type: bool = True,
813833
check_names: bool = True,
814-
check_exact: bool = False,
834+
check_exact: bool | lib.NoDefault = lib.no_default,
815835
check_datetimelike_compat: bool = False,
816836
check_categorical: bool = True,
817837
check_category_order: bool = True,
818838
check_freq: bool = True,
819839
check_flags: bool = True,
820-
rtol: float = 1.0e-5,
821-
atol: float = 1.0e-8,
840+
rtol: float | lib.NoDefault = lib.no_default,
841+
atol: float | lib.NoDefault = lib.no_default,
822842
obj: str = "Series",
823843
*,
824844
check_index: bool = True,
@@ -841,7 +861,12 @@ def assert_series_equal(
841861
check_names : bool, default True
842862
Whether to check the Series and Index names attribute.
843863
check_exact : bool, default False
844-
Whether to compare number exactly. Only takes effect for float dtypes.
864+
Whether to compare number exactly.
865+
866+
.. versionchanged:: 2.2.0
867+
868+
Defaults to True for integer dtypes if none of
869+
``check_exact``, ``rtol`` and ``atol`` are specified.
845870
check_datetimelike_compat : bool, default False
846871
Compare datetime-like which is comparable ignoring dtype.
847872
check_categorical : bool, default True
@@ -877,6 +902,22 @@ def assert_series_equal(
877902
>>> tm.assert_series_equal(a, b)
878903
"""
879904
__tracebackhide__ = True
905+
if (
906+
check_exact is lib.no_default
907+
and rtol is lib.no_default
908+
and atol is lib.no_default
909+
):
910+
check_exact = (
911+
is_numeric_dtype(left.dtype)
912+
and not is_float_dtype(left.dtype)
913+
or is_numeric_dtype(right.dtype)
914+
and not is_float_dtype(right.dtype)
915+
)
916+
elif check_exact is lib.no_default:
917+
check_exact = False
918+
919+
rtol = rtol if rtol is not lib.no_default else 1.0e-5
920+
atol = atol if atol is not lib.no_default else 1.0e-8
880921

881922
if not check_index and check_like:
882923
raise ValueError("check_like must be False if check_index is False")
@@ -931,10 +972,7 @@ def assert_series_equal(
931972
pass
932973
else:
933974
assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
934-
if check_exact or (
935-
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
936-
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
937-
):
975+
if check_exact:
938976
left_values = left._values
939977
right_values = right._values
940978
# Only check exact if dtype is numeric
@@ -1061,14 +1099,14 @@ def assert_frame_equal(
10611099
check_frame_type: bool = True,
10621100
check_names: bool = True,
10631101
by_blocks: bool = False,
1064-
check_exact: bool = False,
1102+
check_exact: bool | lib.NoDefault = lib.no_default,
10651103
check_datetimelike_compat: bool = False,
10661104
check_categorical: bool = True,
10671105
check_like: bool = False,
10681106
check_freq: bool = True,
10691107
check_flags: bool = True,
1070-
rtol: float = 1.0e-5,
1071-
atol: float = 1.0e-8,
1108+
rtol: float | lib.NoDefault = lib.no_default,
1109+
atol: float | lib.NoDefault = lib.no_default,
10721110
obj: str = "DataFrame",
10731111
) -> None:
10741112
"""
@@ -1103,7 +1141,12 @@ def assert_frame_equal(
11031141
Specify how to compare internal data. If False, compare by columns.
11041142
If True, compare by blocks.
11051143
check_exact : bool, default False
1106-
Whether to compare number exactly. Only takes effect for float dtypes.
1144+
Whether to compare number exactly.
1145+
1146+
.. versionchanged:: 2.2.0
1147+
1148+
Defaults to True for integer dtypes if none of
1149+
``check_exact``, ``rtol`` and ``atol`` are specified.
11071150
check_datetimelike_compat : bool, default False
11081151
Compare datetime-like which is comparable ignoring dtype.
11091152
check_categorical : bool, default True
@@ -1158,6 +1201,9 @@ def assert_frame_equal(
11581201
>>> assert_frame_equal(df1, df2, check_dtype=False)
11591202
"""
11601203
__tracebackhide__ = True
1204+
_rtol = rtol if rtol is not lib.no_default else 1.0e-5
1205+
_atol = atol if atol is not lib.no_default else 1.0e-8
1206+
_check_exact = check_exact if check_exact is not lib.no_default else False
11611207

11621208
# instance validation
11631209
_check_isinstance(left, right, DataFrame)
@@ -1181,11 +1227,11 @@ def assert_frame_equal(
11811227
right.index,
11821228
exact=check_index_type,
11831229
check_names=check_names,
1184-
check_exact=check_exact,
1230+
check_exact=_check_exact,
11851231
check_categorical=check_categorical,
11861232
check_order=not check_like,
1187-
rtol=rtol,
1188-
atol=atol,
1233+
rtol=_rtol,
1234+
atol=_atol,
11891235
obj=f"{obj}.index",
11901236
)
11911237

@@ -1195,11 +1241,11 @@ def assert_frame_equal(
11951241
right.columns,
11961242
exact=check_column_type,
11971243
check_names=check_names,
1198-
check_exact=check_exact,
1244+
check_exact=_check_exact,
11991245
check_categorical=check_categorical,
12001246
check_order=not check_like,
1201-
rtol=rtol,
1202-
atol=atol,
1247+
rtol=_rtol,
1248+
atol=_atol,
12031249
obj=f"{obj}.columns",
12041250
)
12051251

pandas/tests/util/test_assert_series_equal.py

+12
Original file line numberDiff line numberDiff line change
@@ -461,3 +461,15 @@ def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
461461
left = Series([1, 2, val], dtype=dtype)
462462
right = Series(pd.array([1, 2, val]))
463463
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)
464+
465+
466+
def test_assert_series_equal_int_tol():
467+
# GH#56646
468+
left = Series([81, 18, 121, 38, 74, 72, 81, 81, 146, 81, 81, 170, 74, 74])
469+
right = Series([72, 9, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72])
470+
tm.assert_series_equal(left, right, rtol=1.5)
471+
472+
tm.assert_frame_equal(left.to_frame(), right.to_frame(), rtol=1.5)
473+
tm.assert_extension_array_equal(
474+
left.astype("Int64").values, right.astype("Int64").values, rtol=1.5
475+
)

0 commit comments

Comments
 (0)