-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
26302 add typing to assert star equal funcs #29364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
3ca64e2
17caa42
9d56cfc
f7392dd
daa9c87
0c2f692
657b3de
f9f4e7c
eb4c25a
7a2ae46
a735027
6d38c1b
e3b63c6
7268afa
9222488
ab95c8a
4a19f0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from shutil import rmtree | ||
import string | ||
import tempfile | ||
from typing import Union, cast | ||
from typing import Optional, Union, cast | ||
import warnings | ||
import zipfile | ||
|
||
|
@@ -53,6 +53,7 @@ | |
Series, | ||
bdate_range, | ||
) | ||
from pandas._typing import AnyArrayLike | ||
from pandas.core.algorithms import take_1d | ||
from pandas.core.arrays import ( | ||
DatetimeArray, | ||
|
@@ -806,8 +807,12 @@ def assert_is_sorted(seq): | |
|
||
|
||
def assert_categorical_equal( | ||
left, right, check_dtype=True, check_category_order=True, obj="Categorical" | ||
): | ||
left: Categorical, | ||
right: Categorical, | ||
check_dtype: bool = True, | ||
check_category_order: bool = True, | ||
obj: str = "Categorical", | ||
) -> None: | ||
"""Test that Categoricals are equivalent. | ||
|
||
Parameters | ||
|
@@ -852,7 +857,12 @@ def assert_categorical_equal( | |
assert_attr_equal("ordered", left, right, obj=obj) | ||
|
||
|
||
def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"): | ||
def assert_interval_array_equal( | ||
left: IntervalArray, | ||
right: IntervalArray, | ||
exact: str = "equiv", | ||
obj: str = "IntervalArray", | ||
) -> None: | ||
"""Test that two IntervalArrays are equivalent. | ||
|
||
Parameters | ||
|
@@ -878,7 +888,9 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") | |
assert_attr_equal("closed", left, right, obj=obj) | ||
|
||
|
||
def assert_period_array_equal(left, right, obj="PeriodArray"): | ||
def assert_period_array_equal( | ||
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray" | ||
) -> None: | ||
_check_isinstance(left, right, PeriodArray) | ||
|
||
assert_numpy_array_equal( | ||
|
@@ -887,7 +899,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"): | |
assert_attr_equal("freq", left, right, obj=obj) | ||
|
||
|
||
def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | ||
def assert_datetime_array_equal( | ||
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, DatetimeArray) | ||
|
||
|
@@ -896,7 +910,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | |
assert_attr_equal("tz", left, right, obj=obj) | ||
|
||
|
||
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"): | ||
def assert_timedelta_array_equal( | ||
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, TimedeltaArray) | ||
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj)) | ||
|
@@ -931,13 +947,13 @@ def raise_assert_detail(obj, message, left, right, diff=None): | |
|
||
|
||
def assert_numpy_array_equal( | ||
left, | ||
right, | ||
strict_nan=False, | ||
check_dtype=True, | ||
err_msg=None, | ||
check_same=None, | ||
obj="numpy array", | ||
left: np.ndarray, | ||
right: np.ndarray, | ||
strict_nan: bool = False, | ||
check_dtype: bool = True, | ||
err_msg: Optional[str] = None, | ||
check_same: Optional[str] = None, | ||
obj: str = "numpy array", | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could annotate return of None here as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you update this one There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
""" Checks that 'np.ndarray' is equivalent | ||
|
||
|
@@ -1067,18 +1083,18 @@ def assert_extension_array_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_series_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_series_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
obj="Series", | ||
): | ||
left: Series, | ||
right: Series, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_series_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
obj: str = "Series", | ||
) -> None: | ||
""" | ||
Check that left and right Series are equal. | ||
|
||
|
@@ -1185,8 +1201,13 @@ def assert_series_equal( | |
right._internal_get_values(), | ||
check_dtype=check_dtype, | ||
) | ||
elif is_interval_dtype(left) or is_interval_dtype(right): | ||
assert_interval_array_equal(left.array, right.array) | ||
elif is_interval_dtype(left) or is_interval_dtype(left): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is now just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, I think change was my mistake. Have removed it. |
||
# must cast to interval dtype to keep mypy happy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What was the complaint on this? This changes the actual assertions being done I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I don't cast to IntervalArray I get these errors
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've removed the asserts though, they weren't needed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned by @jreback should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can delete this comment |
||
assert is_interval_dtype(right) | ||
assert is_interval_dtype(left) | ||
left_array = IntervalArray(left.array) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah we don't want to do this (cast is ok), but don't actually coerce with a constructor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ha got it, wasn't aware of typing.cast. Have updated the code. |
||
right_array = IntervalArray(right.array) | ||
assert_interval_array_equal(left_array, right_array) | ||
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype): | ||
# .values is an ndarray, but ._values is the ExtensionArray. | ||
# TODO: Use .array | ||
|
@@ -1221,21 +1242,21 @@ def assert_series_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_frame_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_column_type="equiv", | ||
check_frame_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
by_blocks=False, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
check_like=False, | ||
obj="DataFrame", | ||
): | ||
left: DataFrame, | ||
right: DataFrame, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_column_type: str = "equiv", | ||
check_frame_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
by_blocks: bool = False, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
check_like: bool = False, | ||
obj: str = "DataFrame", | ||
) -> None: | ||
""" | ||
Check that left and right DataFrame are equal. | ||
|
||
|
@@ -1403,7 +1424,11 @@ def assert_frame_equal( | |
) | ||
|
||
|
||
def assert_equal(left, right, **kwargs): | ||
def assert_equal( | ||
left: Union[DataFrame, AnyArrayLike], | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
right: Union[DataFrame, AnyArrayLike], | ||
**kwargs | ||
) -> None: | ||
""" | ||
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. | ||
|
||
|
@@ -1415,27 +1440,36 @@ def assert_equal(left, right, **kwargs): | |
""" | ||
__tracebackhide__ = True | ||
|
||
if isinstance(left, pd.Index): | ||
if isinstance(left, Index): | ||
assert isinstance(right, Index) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the message here? These changes also impact the tests so don't think we want these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah we don't want to actually change the semantics here, remove these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I comment out
But assuming left is an index, assert_index_equal(left, right) should raise if right isn't an index, so I don't think adding an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I've just pushed another attempt (0c2f692) that uses the |
||
assert_index_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.Series): | ||
elif isinstance(left, Series): | ||
assert isinstance(right, Series) | ||
assert_series_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.DataFrame): | ||
elif isinstance(left, DataFrame): | ||
assert isinstance(right, DataFrame) | ||
assert_frame_equal(left, right, **kwargs) | ||
elif isinstance(left, IntervalArray): | ||
assert isinstance(right, IntervalArray) | ||
assert_interval_array_equal(left, right, **kwargs) | ||
elif isinstance(left, PeriodArray): | ||
assert isinstance(right, PeriodArray) | ||
assert_period_array_equal(left, right, **kwargs) | ||
elif isinstance(left, DatetimeArray): | ||
assert isinstance(right, DatetimeArray) | ||
assert_datetime_array_equal(left, right, **kwargs) | ||
elif isinstance(left, TimedeltaArray): | ||
assert isinstance(right, TimedeltaArray) | ||
assert_timedelta_array_equal(left, right, **kwargs) | ||
elif isinstance(left, ExtensionArray): | ||
assert isinstance(right, ExtensionArray) | ||
assert_extension_array_equal(left, right, **kwargs) | ||
elif isinstance(left, np.ndarray): | ||
assert isinstance(right, np.ndarray) | ||
assert_numpy_array_equal(left, right, **kwargs) | ||
elif isinstance(left, str): | ||
assert kwargs == {} | ||
return left == right | ||
assert left == right | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raise NotImplementedError(type(left)) | ||
|
||
|
@@ -1497,12 +1531,12 @@ def to_array(obj): | |
|
||
|
||
def assert_sp_array_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_kind=True, | ||
check_fill_value=True, | ||
consolidate_block_indices=False, | ||
left: pd.SparseArray, | ||
right: pd.SparseArray, | ||
check_dtype: bool = True, | ||
check_kind: bool = True, | ||
check_fill_value: bool = True, | ||
consolidate_block_indices: bool = False, | ||
): | ||
"""Check that the left and right SparseArray are equal. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AnyArrayLike resolves to Any.
If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.