Skip to content

Commit 17caa42

Browse files
Add types to all assert_*_equal test functions
1 parent 3ca64e2 commit 17caa42

File tree

1 file changed

+79
-39
lines changed

1 file changed

+79
-39
lines changed

pandas/util/testing.py

+79-39
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from shutil import rmtree
99
import string
1010
import tempfile
11-
from typing import Union, cast
11+
from typing import Optional, Union, cast
1212
import warnings
1313
import zipfile
1414

@@ -53,6 +53,7 @@
5353
Series,
5454
bdate_range,
5555
)
56+
from pandas._typing import AnyArrayLike
5657
from pandas.core.algorithms import take_1d
5758
from pandas.core.arrays import (
5859
DatetimeArray,
@@ -806,8 +807,12 @@ def assert_is_sorted(seq):
806807

807808

808809
def assert_categorical_equal(
809-
left, right, check_dtype=True, check_category_order=True, obj="Categorical"
810-
):
810+
left: Categorical,
811+
right: Categorical,
812+
check_dtype: bool = True,
813+
check_category_order: bool = True,
814+
obj: str = "Categorical"
815+
) -> None:
811816
"""Test that Categoricals are equivalent.
812817
813818
Parameters
@@ -852,7 +857,12 @@ def assert_categorical_equal(
852857
assert_attr_equal("ordered", left, right, obj=obj)
853858

854859

855-
def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
860+
def assert_interval_array_equal(
861+
left: IntervalArray,
862+
right: IntervalArray,
863+
exact: str = "equiv",
864+
obj: str = "IntervalArray"
865+
) -> None:
856866
"""Test that two IntervalArrays are equivalent.
857867
858868
Parameters
@@ -878,7 +888,11 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
878888
assert_attr_equal("closed", left, right, obj=obj)
879889

880890

881-
def assert_period_array_equal(left, right, obj="PeriodArray"):
891+
def assert_period_array_equal(
892+
left: PeriodArray,
893+
right: PeriodArray,
894+
obj: str = "PeriodArray"
895+
) -> None:
882896
_check_isinstance(left, right, PeriodArray)
883897

884898
assert_numpy_array_equal(
@@ -887,7 +901,11 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
887901
assert_attr_equal("freq", left, right, obj=obj)
888902

889903

890-
def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
904+
def assert_datetime_array_equal(
905+
left: DatetimeArray,
906+
right: DatetimeArray,
907+
obj: str = "DatetimeArray"
908+
) -> None:
891909
__tracebackhide__ = True
892910
_check_isinstance(left, right, DatetimeArray)
893911

@@ -896,7 +914,11 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
896914
assert_attr_equal("tz", left, right, obj=obj)
897915

898916

899-
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
917+
def assert_timedelta_array_equal(
918+
left: TimedeltaArray,
919+
right: TimedeltaArray,
920+
obj: str = "TimedeltaArray"
921+
) -> None:
900922
__tracebackhide__ = True
901923
_check_isinstance(left, right, TimedeltaArray)
902924
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj))
@@ -931,13 +953,13 @@ def raise_assert_detail(obj, message, left, right, diff=None):
931953

932954

933955
def assert_numpy_array_equal(
934-
left,
935-
right,
936-
strict_nan=False,
937-
check_dtype=True,
938-
err_msg=None,
939-
check_same=None,
940-
obj="numpy array",
956+
left: np.ndarray,
957+
right: np.ndarray,
958+
strict_nan: bool = False,
959+
check_dtype: bool = True,
960+
err_msg: Optional[str] = None,
961+
check_same: Optional[str] = None,
962+
obj: str = "numpy array",
941963
):
942964
""" Checks that 'np.ndarray' is equivalent
943965
@@ -1067,18 +1089,18 @@ def assert_extension_array_equal(
10671089

10681090
# This could be refactored to use the NDFrame.equals method
10691091
def assert_series_equal(
1070-
left,
1071-
right,
1072-
check_dtype=True,
1073-
check_index_type="equiv",
1074-
check_series_type=True,
1075-
check_less_precise=False,
1076-
check_names=True,
1077-
check_exact=False,
1078-
check_datetimelike_compat=False,
1079-
check_categorical=True,
1080-
obj="Series",
1081-
):
1092+
left: Series,
1093+
right: Series,
1094+
check_dtype: bool = True,
1095+
check_index_type: str = "equiv",
1096+
check_series_type: bool = True,
1097+
check_less_precise: bool = False,
1098+
check_names: bool = True,
1099+
check_exact: bool = False,
1100+
check_datetimelike_compat: bool = False,
1101+
check_categorical: bool = True,
1102+
obj: str = "Series",
1103+
) -> None:
10821104
"""
10831105
Check that left and right Series are equal.
10841106
@@ -1185,8 +1207,13 @@ def assert_series_equal(
11851207
right._internal_get_values(),
11861208
check_dtype=check_dtype,
11871209
)
1188-
elif is_interval_dtype(left) or is_interval_dtype(right):
1189-
assert_interval_array_equal(left.array, right.array)
1210+
elif is_interval_dtype(left) or is_interval_dtype(left):
1211+
# must cast to interval dtype to keep mypy happy
1212+
assert is_interval_dtype(right)
1213+
assert is_interval_dtype(left)
1214+
left_array = IntervalArray(left.array)
1215+
right_array = IntervalArray(right.array)
1216+
assert_interval_array_equal(left_array, right_array)
11901217
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
11911218
# .values is an ndarray, but ._values is the ExtensionArray.
11921219
# TODO: Use .array
@@ -1403,7 +1430,11 @@ def assert_frame_equal(
14031430
)
14041431

14051432

1406-
def assert_equal(left, right, **kwargs):
1433+
def assert_equal(
1434+
left: Union[DataFrame, AnyArrayLike],
1435+
right: Union[DataFrame, AnyArrayLike],
1436+
**kwargs
1437+
) -> None:
14071438
"""
14081439
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.
14091440
@@ -1415,27 +1446,36 @@ def assert_equal(left, right, **kwargs):
14151446
"""
14161447
__tracebackhide__ = True
14171448

1418-
if isinstance(left, pd.Index):
1449+
if isinstance(left, Index):
1450+
assert isinstance(right, Index)
14191451
assert_index_equal(left, right, **kwargs)
1420-
elif isinstance(left, pd.Series):
1452+
elif isinstance(left, Series):
1453+
assert isinstance(right, Series)
14211454
assert_series_equal(left, right, **kwargs)
1422-
elif isinstance(left, pd.DataFrame):
1455+
elif isinstance(left, DataFrame):
1456+
assert isinstance(right, DataFrame)
14231457
assert_frame_equal(left, right, **kwargs)
14241458
elif isinstance(left, IntervalArray):
1459+
assert isinstance(right, IntervalArray)
14251460
assert_interval_array_equal(left, right, **kwargs)
14261461
elif isinstance(left, PeriodArray):
1462+
assert isinstance(right, PeriodArray)
14271463
assert_period_array_equal(left, right, **kwargs)
14281464
elif isinstance(left, DatetimeArray):
1465+
assert isinstance(right, DatetimeArray)
14291466
assert_datetime_array_equal(left, right, **kwargs)
14301467
elif isinstance(left, TimedeltaArray):
1468+
assert isinstance(right, TimedeltaArray)
14311469
assert_timedelta_array_equal(left, right, **kwargs)
14321470
elif isinstance(left, ExtensionArray):
1471+
assert isinstance(right, ExtensionArray)
14331472
assert_extension_array_equal(left, right, **kwargs)
14341473
elif isinstance(left, np.ndarray):
1474+
assert isinstance(right, np.ndarray)
14351475
assert_numpy_array_equal(left, right, **kwargs)
14361476
elif isinstance(left, str):
14371477
assert kwargs == {}
1438-
return left == right
1478+
assert left == right
14391479
else:
14401480
raise NotImplementedError(type(left))
14411481

@@ -1497,12 +1537,12 @@ def to_array(obj):
14971537

14981538

14991539
def assert_sp_array_equal(
1500-
left,
1501-
right,
1502-
check_dtype=True,
1503-
check_kind=True,
1504-
check_fill_value=True,
1505-
consolidate_block_indices=False,
1540+
left: pd.SparseArray,
1541+
right: pd.SparseArray,
1542+
check_dtype: bool = True,
1543+
check_kind: bool = True,
1544+
check_fill_value: bool = True,
1545+
consolidate_block_indices: bool = False,
15061546
):
15071547
"""Check that the left and right SparseArray are equal.
15081548

0 commit comments

Comments
 (0)