8
8
from shutil import rmtree
9
9
import string
10
10
import tempfile
11
- from typing import Union , cast
11
+ from typing import Optional , Union , cast
12
12
import warnings
13
13
import zipfile
14
14
53
53
Series ,
54
54
bdate_range ,
55
55
)
56
+ from pandas ._typing import AnyArrayLike
56
57
from pandas .core .algorithms import take_1d
57
58
from pandas .core .arrays import (
58
59
DatetimeArray ,
@@ -806,8 +807,12 @@ def assert_is_sorted(seq):
806
807
807
808
808
809
def assert_categorical_equal (
809
- left , right , check_dtype = True , check_category_order = True , obj = "Categorical"
810
- ):
810
+ left : Categorical ,
811
+ right : Categorical ,
812
+ check_dtype : bool = True ,
813
+ check_category_order : bool = True ,
814
+ obj : str = "Categorical"
815
+ ) -> None :
811
816
"""Test that Categoricals are equivalent.
812
817
813
818
Parameters
@@ -852,7 +857,12 @@ def assert_categorical_equal(
852
857
assert_attr_equal ("ordered" , left , right , obj = obj )
853
858
854
859
855
- def assert_interval_array_equal (left , right , exact = "equiv" , obj = "IntervalArray" ):
860
+ def assert_interval_array_equal (
861
+ left : IntervalArray ,
862
+ right : IntervalArray ,
863
+ exact : str = "equiv" ,
864
+ obj : str = "IntervalArray"
865
+ ) -> None :
856
866
"""Test that two IntervalArrays are equivalent.
857
867
858
868
Parameters
@@ -878,7 +888,11 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
878
888
assert_attr_equal ("closed" , left , right , obj = obj )
879
889
880
890
881
- def assert_period_array_equal (left , right , obj = "PeriodArray" ):
891
+ def assert_period_array_equal (
892
+ left : PeriodArray ,
893
+ right : PeriodArray ,
894
+ obj : str = "PeriodArray"
895
+ ) -> None :
882
896
_check_isinstance (left , right , PeriodArray )
883
897
884
898
assert_numpy_array_equal (
@@ -887,7 +901,11 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
887
901
assert_attr_equal ("freq" , left , right , obj = obj )
888
902
889
903
890
- def assert_datetime_array_equal (left , right , obj = "DatetimeArray" ):
904
+ def assert_datetime_array_equal (
905
+ left : DatetimeArray ,
906
+ right : DatetimeArray ,
907
+ obj : str = "DatetimeArray"
908
+ ) -> None :
891
909
__tracebackhide__ = True
892
910
_check_isinstance (left , right , DatetimeArray )
893
911
@@ -896,7 +914,11 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
896
914
assert_attr_equal ("tz" , left , right , obj = obj )
897
915
898
916
899
- def assert_timedelta_array_equal (left , right , obj = "TimedeltaArray" ):
917
+ def assert_timedelta_array_equal (
918
+ left : TimedeltaArray ,
919
+ right : TimedeltaArray ,
920
+ obj : str = "TimedeltaArray"
921
+ ) -> None :
900
922
__tracebackhide__ = True
901
923
_check_isinstance (left , right , TimedeltaArray )
902
924
assert_numpy_array_equal (left ._data , right ._data , obj = "{obj}._data" .format (obj = obj ))
@@ -931,13 +953,13 @@ def raise_assert_detail(obj, message, left, right, diff=None):
931
953
932
954
933
955
def assert_numpy_array_equal (
934
- left ,
935
- right ,
936
- strict_nan = False ,
937
- check_dtype = True ,
938
- err_msg = None ,
939
- check_same = None ,
940
- obj = "numpy array" ,
956
+ left : np . ndarray ,
957
+ right : np . ndarray ,
958
+ strict_nan : bool = False ,
959
+ check_dtype : bool = True ,
960
+ err_msg : Optional [ str ] = None ,
961
+ check_same : Optional [ str ] = None ,
962
+ obj : str = "numpy array" ,
941
963
):
942
964
""" Checks that 'np.ndarray' is equivalent
943
965
@@ -1067,18 +1089,18 @@ def assert_extension_array_equal(
1067
1089
1068
1090
# This could be refactored to use the NDFrame.equals method
1069
1091
def assert_series_equal (
1070
- left ,
1071
- right ,
1072
- check_dtype = True ,
1073
- check_index_type = "equiv" ,
1074
- check_series_type = True ,
1075
- check_less_precise = False ,
1076
- check_names = True ,
1077
- check_exact = False ,
1078
- check_datetimelike_compat = False ,
1079
- check_categorical = True ,
1080
- obj = "Series" ,
1081
- ):
1092
+ left : Series ,
1093
+ right : Series ,
1094
+ check_dtype : bool = True ,
1095
+ check_index_type : str = "equiv" ,
1096
+ check_series_type : bool = True ,
1097
+ check_less_precise : bool = False ,
1098
+ check_names : bool = True ,
1099
+ check_exact : bool = False ,
1100
+ check_datetimelike_compat : bool = False ,
1101
+ check_categorical : bool = True ,
1102
+ obj : str = "Series" ,
1103
+ ) -> None :
1082
1104
"""
1083
1105
Check that left and right Series are equal.
1084
1106
@@ -1185,8 +1207,13 @@ def assert_series_equal(
1185
1207
right ._internal_get_values (),
1186
1208
check_dtype = check_dtype ,
1187
1209
)
1188
- elif is_interval_dtype (left ) or is_interval_dtype (right ):
1189
- assert_interval_array_equal (left .array , right .array )
1210
+ elif is_interval_dtype (left ) or is_interval_dtype (left ):
1211
+ # must cast to interval dtype to keep mypy happy
1212
+ assert is_interval_dtype (right )
1213
+ assert is_interval_dtype (left )
1214
+ left_array = IntervalArray (left .array )
1215
+ right_array = IntervalArray (right .array )
1216
+ assert_interval_array_equal (left_array , right_array )
1190
1217
elif is_extension_array_dtype (left .dtype ) and is_datetime64tz_dtype (left .dtype ):
1191
1218
# .values is an ndarray, but ._values is the ExtensionArray.
1192
1219
# TODO: Use .array
@@ -1403,7 +1430,11 @@ def assert_frame_equal(
1403
1430
)
1404
1431
1405
1432
1406
- def assert_equal (left , right , ** kwargs ):
1433
+ def assert_equal (
1434
+ left : Union [DataFrame , AnyArrayLike ],
1435
+ right : Union [DataFrame , AnyArrayLike ],
1436
+ ** kwargs
1437
+ ) -> None :
1407
1438
"""
1408
1439
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.
1409
1440
@@ -1415,27 +1446,36 @@ def assert_equal(left, right, **kwargs):
1415
1446
"""
1416
1447
__tracebackhide__ = True
1417
1448
1418
- if isinstance (left , pd .Index ):
1449
+ if isinstance (left , Index ):
1450
+ assert isinstance (right , Index )
1419
1451
assert_index_equal (left , right , ** kwargs )
1420
- elif isinstance (left , pd .Series ):
1452
+ elif isinstance (left , Series ):
1453
+ assert isinstance (right , Series )
1421
1454
assert_series_equal (left , right , ** kwargs )
1422
- elif isinstance (left , pd .DataFrame ):
1455
+ elif isinstance (left , DataFrame ):
1456
+ assert isinstance (right , DataFrame )
1423
1457
assert_frame_equal (left , right , ** kwargs )
1424
1458
elif isinstance (left , IntervalArray ):
1459
+ assert isinstance (right , IntervalArray )
1425
1460
assert_interval_array_equal (left , right , ** kwargs )
1426
1461
elif isinstance (left , PeriodArray ):
1462
+ assert isinstance (right , PeriodArray )
1427
1463
assert_period_array_equal (left , right , ** kwargs )
1428
1464
elif isinstance (left , DatetimeArray ):
1465
+ assert isinstance (right , DatetimeArray )
1429
1466
assert_datetime_array_equal (left , right , ** kwargs )
1430
1467
elif isinstance (left , TimedeltaArray ):
1468
+ assert isinstance (right , TimedeltaArray )
1431
1469
assert_timedelta_array_equal (left , right , ** kwargs )
1432
1470
elif isinstance (left , ExtensionArray ):
1471
+ assert isinstance (right , ExtensionArray )
1433
1472
assert_extension_array_equal (left , right , ** kwargs )
1434
1473
elif isinstance (left , np .ndarray ):
1474
+ assert isinstance (right , np .ndarray )
1435
1475
assert_numpy_array_equal (left , right , ** kwargs )
1436
1476
elif isinstance (left , str ):
1437
1477
assert kwargs == {}
1438
- return left == right
1478
+ assert left == right
1439
1479
else :
1440
1480
raise NotImplementedError (type (left ))
1441
1481
@@ -1497,12 +1537,12 @@ def to_array(obj):
1497
1537
1498
1538
1499
1539
def assert_sp_array_equal (
1500
- left ,
1501
- right ,
1502
- check_dtype = True ,
1503
- check_kind = True ,
1504
- check_fill_value = True ,
1505
- consolidate_block_indices = False ,
1540
+ left : pd . SparseArray ,
1541
+ right : pd . SparseArray ,
1542
+ check_dtype : bool = True ,
1543
+ check_kind : bool = True ,
1544
+ check_fill_value : bool = True ,
1545
+ consolidate_block_indices : bool = False ,
1506
1546
):
1507
1547
"""Check that the left and right SparseArray are equal.
1508
1548
0 commit comments