10
10
11
11
import numpy as np
12
12
13
+ from pandas ._libs import lib
13
14
from pandas ._libs .missing import is_matching_na
14
15
from pandas ._libs .sparse import SparseIndex
15
16
import pandas ._libs .testing as _testing
@@ -698,9 +699,9 @@ def assert_extension_array_equal(
698
699
right ,
699
700
check_dtype : bool | Literal ["equiv" ] = True ,
700
701
index_values = None ,
701
- check_exact : bool = False ,
702
- rtol : float = 1.0e-5 ,
703
- atol : float = 1.0e-8 ,
702
+ check_exact : bool | lib . NoDefault = lib . no_default ,
703
+ rtol : float | lib . NoDefault = lib . no_default ,
704
+ atol : float | lib . NoDefault = lib . no_default ,
704
705
obj : str = "ExtensionArray" ,
705
706
) -> None :
706
707
"""
@@ -715,7 +716,12 @@ def assert_extension_array_equal(
715
716
index_values : Index | numpy.ndarray, default None
716
717
Optional index (shared by both left and right), used in output.
717
718
check_exact : bool, default False
718
- Whether to compare number exactly. Only takes effect for float dtypes.
719
+ Whether to compare number exactly.
720
+
721
+ .. versionchanged:: 2.2.0
722
+
723
+ Defaults to True for integer dtypes if none of
724
+ ``check_exact``, ``rtol`` and ``atol`` are specified.
719
725
rtol : float, default 1e-5
720
726
Relative tolerance. Only used when check_exact is False.
721
727
atol : float, default 1e-8
@@ -739,6 +745,23 @@ def assert_extension_array_equal(
739
745
>>> b, c = a.array, a.array
740
746
>>> tm.assert_extension_array_equal(b, c)
741
747
"""
748
+ if (
749
+ check_exact is lib .no_default
750
+ and rtol is lib .no_default
751
+ and atol is lib .no_default
752
+ ):
753
+ check_exact = (
754
+ is_numeric_dtype (left .dtype )
755
+ and not is_float_dtype (left .dtype )
756
+ or is_numeric_dtype (right .dtype )
757
+ and not is_float_dtype (right .dtype )
758
+ )
759
+ elif check_exact is lib .no_default :
760
+ check_exact = False
761
+
762
+ rtol = rtol if rtol is not lib .no_default else 1.0e-5
763
+ atol = atol if atol is not lib .no_default else 1.0e-8
764
+
742
765
assert isinstance (left , ExtensionArray ), "left is not an ExtensionArray"
743
766
assert isinstance (right , ExtensionArray ), "right is not an ExtensionArray"
744
767
if check_dtype :
@@ -784,10 +807,7 @@ def assert_extension_array_equal(
784
807
785
808
left_valid = left [~ left_na ].to_numpy (dtype = object )
786
809
right_valid = right [~ right_na ].to_numpy (dtype = object )
787
- if check_exact or (
788
- (is_numeric_dtype (left .dtype ) and not is_float_dtype (left .dtype ))
789
- or (is_numeric_dtype (right .dtype ) and not is_float_dtype (right .dtype ))
790
- ):
810
+ if check_exact :
791
811
assert_numpy_array_equal (
792
812
left_valid , right_valid , obj = obj , index_values = index_values
793
813
)
@@ -811,14 +831,14 @@ def assert_series_equal(
811
831
check_index_type : bool | Literal ["equiv" ] = "equiv" ,
812
832
check_series_type : bool = True ,
813
833
check_names : bool = True ,
814
- check_exact : bool = False ,
834
+ check_exact : bool | lib . NoDefault = lib . no_default ,
815
835
check_datetimelike_compat : bool = False ,
816
836
check_categorical : bool = True ,
817
837
check_category_order : bool = True ,
818
838
check_freq : bool = True ,
819
839
check_flags : bool = True ,
820
- rtol : float = 1.0e-5 ,
821
- atol : float = 1.0e-8 ,
840
+ rtol : float | lib . NoDefault = lib . no_default ,
841
+ atol : float | lib . NoDefault = lib . no_default ,
822
842
obj : str = "Series" ,
823
843
* ,
824
844
check_index : bool = True ,
@@ -841,7 +861,12 @@ def assert_series_equal(
841
861
check_names : bool, default True
842
862
Whether to check the Series and Index names attribute.
843
863
check_exact : bool, default False
844
- Whether to compare number exactly. Only takes effect for float dtypes.
864
+ Whether to compare number exactly.
865
+
866
+ .. versionchanged:: 2.2.0
867
+
868
+ Defaults to True for integer dtypes if none of
869
+ ``check_exact``, ``rtol`` and ``atol`` are specified.
845
870
check_datetimelike_compat : bool, default False
846
871
Compare datetime-like which is comparable ignoring dtype.
847
872
check_categorical : bool, default True
@@ -877,6 +902,22 @@ def assert_series_equal(
877
902
>>> tm.assert_series_equal(a, b)
878
903
"""
879
904
__tracebackhide__ = True
905
+ if (
906
+ check_exact is lib .no_default
907
+ and rtol is lib .no_default
908
+ and atol is lib .no_default
909
+ ):
910
+ check_exact = (
911
+ is_numeric_dtype (left .dtype )
912
+ and not is_float_dtype (left .dtype )
913
+ or is_numeric_dtype (right .dtype )
914
+ and not is_float_dtype (right .dtype )
915
+ )
916
+ elif check_exact is lib .no_default :
917
+ check_exact = False
918
+
919
+ rtol = rtol if rtol is not lib .no_default else 1.0e-5
920
+ atol = atol if atol is not lib .no_default else 1.0e-8
880
921
881
922
if not check_index and check_like :
882
923
raise ValueError ("check_like must be False if check_index is False" )
@@ -931,10 +972,7 @@ def assert_series_equal(
931
972
pass
932
973
else :
933
974
assert_attr_equal ("dtype" , left , right , obj = f"Attributes of { obj } " )
934
- if check_exact or (
935
- (is_numeric_dtype (left .dtype ) and not is_float_dtype (left .dtype ))
936
- or (is_numeric_dtype (right .dtype ) and not is_float_dtype (right .dtype ))
937
- ):
975
+ if check_exact :
938
976
left_values = left ._values
939
977
right_values = right ._values
940
978
# Only check exact if dtype is numeric
@@ -1061,14 +1099,14 @@ def assert_frame_equal(
1061
1099
check_frame_type : bool = True ,
1062
1100
check_names : bool = True ,
1063
1101
by_blocks : bool = False ,
1064
- check_exact : bool = False ,
1102
+ check_exact : bool | lib . NoDefault = lib . no_default ,
1065
1103
check_datetimelike_compat : bool = False ,
1066
1104
check_categorical : bool = True ,
1067
1105
check_like : bool = False ,
1068
1106
check_freq : bool = True ,
1069
1107
check_flags : bool = True ,
1070
- rtol : float = 1.0e-5 ,
1071
- atol : float = 1.0e-8 ,
1108
+ rtol : float | lib . NoDefault = lib . no_default ,
1109
+ atol : float | lib . NoDefault = lib . no_default ,
1072
1110
obj : str = "DataFrame" ,
1073
1111
) -> None :
1074
1112
"""
@@ -1103,7 +1141,12 @@ def assert_frame_equal(
1103
1141
Specify how to compare internal data. If False, compare by columns.
1104
1142
If True, compare by blocks.
1105
1143
check_exact : bool, default False
1106
- Whether to compare number exactly. Only takes effect for float dtypes.
1144
+ Whether to compare number exactly.
1145
+
1146
+ .. versionchanged:: 2.2.0
1147
+
1148
+ Defaults to True for integer dtypes if none of
1149
+ ``check_exact``, ``rtol`` and ``atol`` are specified.
1107
1150
check_datetimelike_compat : bool, default False
1108
1151
Compare datetime-like which is comparable ignoring dtype.
1109
1152
check_categorical : bool, default True
@@ -1158,6 +1201,9 @@ def assert_frame_equal(
1158
1201
>>> assert_frame_equal(df1, df2, check_dtype=False)
1159
1202
"""
1160
1203
__tracebackhide__ = True
1204
+ _rtol = rtol if rtol is not lib .no_default else 1.0e-5
1205
+ _atol = atol if atol is not lib .no_default else 1.0e-8
1206
+ _check_exact = check_exact if check_exact is not lib .no_default else False
1161
1207
1162
1208
# instance validation
1163
1209
_check_isinstance (left , right , DataFrame )
@@ -1181,11 +1227,11 @@ def assert_frame_equal(
1181
1227
right .index ,
1182
1228
exact = check_index_type ,
1183
1229
check_names = check_names ,
1184
- check_exact = check_exact ,
1230
+ check_exact = _check_exact ,
1185
1231
check_categorical = check_categorical ,
1186
1232
check_order = not check_like ,
1187
- rtol = rtol ,
1188
- atol = atol ,
1233
+ rtol = _rtol ,
1234
+ atol = _atol ,
1189
1235
obj = f"{ obj } .index" ,
1190
1236
)
1191
1237
@@ -1195,11 +1241,11 @@ def assert_frame_equal(
1195
1241
right .columns ,
1196
1242
exact = check_column_type ,
1197
1243
check_names = check_names ,
1198
- check_exact = check_exact ,
1244
+ check_exact = _check_exact ,
1199
1245
check_categorical = check_categorical ,
1200
1246
check_order = not check_like ,
1201
- rtol = rtol ,
1202
- atol = atol ,
1247
+ rtol = _rtol ,
1248
+ atol = _atol ,
1203
1249
obj = f"{ obj } .columns" ,
1204
1250
)
1205
1251
0 commit comments