-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
26302 add typing to assert star equal funcs #29364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
3ca64e2
17caa42
9d56cfc
f7392dd
daa9c87
0c2f692
657b3de
f9f4e7c
eb4c25a
7a2ae46
a735027
6d38c1b
e3b63c6
7268afa
9222488
ab95c8a
4a19f0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from shutil import rmtree | ||
import string | ||
import tempfile | ||
from typing import Union, cast | ||
from typing import Optional, Union, cast | ||
import warnings | ||
import zipfile | ||
|
||
|
@@ -53,6 +53,7 @@ | |
Series, | ||
bdate_range, | ||
) | ||
from pandas._typing import AnyArrayLike | ||
from pandas.core.algorithms import take_1d | ||
from pandas.core.arrays import ( | ||
DatetimeArray, | ||
|
@@ -806,8 +807,12 @@ def assert_is_sorted(seq): | |
|
||
|
||
def assert_categorical_equal( | ||
left, right, check_dtype=True, check_category_order=True, obj="Categorical" | ||
): | ||
left: Categorical, | ||
right: Categorical, | ||
check_dtype: bool = True, | ||
check_category_order: bool = True, | ||
obj: str = "Categorical", | ||
) -> None: | ||
"""Test that Categoricals are equivalent. | ||
|
||
Parameters | ||
|
@@ -852,7 +857,12 @@ def assert_categorical_equal( | |
assert_attr_equal("ordered", left, right, obj=obj) | ||
|
||
|
||
def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"): | ||
def assert_interval_array_equal( | ||
left: IntervalArray, | ||
right: IntervalArray, | ||
exact: str = "equiv", | ||
obj: str = "IntervalArray", | ||
) -> None: | ||
"""Test that two IntervalArrays are equivalent. | ||
|
||
Parameters | ||
|
@@ -867,18 +877,20 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") | |
Specify object name being compared, internally used to show appropriate | ||
assertion message | ||
""" | ||
_check_isinstance(left, right, IntervalArray) | ||
|
||
assert_index_equal( | ||
left.left, right.left, exact=exact, obj="{obj}.left".format(obj=obj) | ||
) | ||
assert_index_equal( | ||
left.right, right.right, exact=exact, obj="{obj}.left".format(obj=obj) | ||
) | ||
left = cast(IntervalArray, left) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these casts required? Seems like they only subsequently get sent to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these casts are not required here.
|
||
right = cast(IntervalArray, right) | ||
assert_attr_equal("closed", left, right, obj=obj) | ||
|
||
|
||
def assert_period_array_equal(left, right, obj="PeriodArray"): | ||
def assert_period_array_equal( | ||
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray" | ||
) -> None: | ||
_check_isinstance(left, right, PeriodArray) | ||
|
||
assert_numpy_array_equal( | ||
|
@@ -887,7 +899,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"): | |
assert_attr_equal("freq", left, right, obj=obj) | ||
|
||
|
||
def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | ||
def assert_datetime_array_equal( | ||
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, DatetimeArray) | ||
|
||
|
@@ -896,7 +910,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | |
assert_attr_equal("tz", left, right, obj=obj) | ||
|
||
|
||
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"): | ||
def assert_timedelta_array_equal( | ||
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, TimedeltaArray) | ||
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj)) | ||
|
@@ -931,14 +947,14 @@ def raise_assert_detail(obj, message, left, right, diff=None): | |
|
||
|
||
def assert_numpy_array_equal( | ||
left, | ||
right, | ||
strict_nan=False, | ||
check_dtype=True, | ||
err_msg=None, | ||
check_same=None, | ||
obj="numpy array", | ||
): | ||
left: np.ndarray, | ||
right: np.ndarray, | ||
strict_nan: bool = False, | ||
check_dtype: bool = True, | ||
err_msg: Optional[str] = None, | ||
check_same: Optional[str] = None, | ||
obj: str = "numpy array", | ||
) -> None: | ||
""" Checks that 'np.ndarray' is equivalent | ||
|
||
Parameters | ||
|
@@ -1067,18 +1083,18 @@ def assert_extension_array_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_series_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_series_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
obj="Series", | ||
): | ||
left: Series, | ||
right: Series, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_series_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
obj: str = "Series", | ||
) -> None: | ||
""" | ||
Check that left and right Series are equal. | ||
|
||
|
@@ -1185,8 +1201,11 @@ def assert_series_equal( | |
right._internal_get_values(), | ||
check_dtype=check_dtype, | ||
) | ||
elif is_interval_dtype(left) or is_interval_dtype(right): | ||
assert_interval_array_equal(left.array, right.array) | ||
elif is_interval_dtype(left) or is_interval_dtype(left): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is now just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, I think change was my mistake. Have removed it. |
||
# must cast to interval dtype to keep mypy happy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What was the complaint on this? This changes the actual assertions being done I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I don't cast to IntervalArray I get these errors
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've removed the asserts though, they weren't needed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned by @jreback should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can delete this comment |
||
left_array = cast(IntervalArray, left.array) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
right_array = cast(IntervalArray, right.array) | ||
assert_interval_array_equal(left_array, right_array) | ||
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype): | ||
# .values is an ndarray, but ._values is the ExtensionArray. | ||
# TODO: Use .array | ||
|
@@ -1221,21 +1240,21 @@ def assert_series_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_frame_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_column_type="equiv", | ||
check_frame_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
by_blocks=False, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
check_like=False, | ||
obj="DataFrame", | ||
): | ||
left: DataFrame, | ||
right: DataFrame, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_column_type: str = "equiv", | ||
check_frame_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
by_blocks: bool = False, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
check_like: bool = False, | ||
obj: str = "DataFrame", | ||
) -> None: | ||
""" | ||
Check that left and right DataFrame are equal. | ||
|
||
|
@@ -1403,7 +1422,11 @@ def assert_frame_equal( | |
) | ||
|
||
|
||
def assert_equal(left, right, **kwargs): | ||
def assert_equal( | ||
left: Union[DataFrame, AnyArrayLike], | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
right: Union[DataFrame, AnyArrayLike], | ||
**kwargs | ||
) -> None: | ||
""" | ||
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. | ||
|
||
|
@@ -1415,27 +1438,36 @@ def assert_equal(left, right, **kwargs): | |
""" | ||
__tracebackhide__ = True | ||
|
||
if isinstance(left, pd.Index): | ||
if isinstance(left, Index): | ||
right = cast(Index, right) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't need any of the casts here; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if I remove the cast I get the following error.
|
||
assert_index_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.Series): | ||
elif isinstance(left, Series): | ||
right = cast(Series, right) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these casts shouldn't be added here. casts are for when we know better than mypy. in these cases the second argument to assert_series_equal etc may be anything. since the is_instance checks are performed in assert_series_equal etc. it may be better to type all the the assert functions as
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These functions do expect certain types (as described in the docstring), so I would think we'd want to be more explicit here, even if we need to add a few casts. These casts aren't happening at runtime anyway, right? https://docs.python.org/3/library/typing.html#typing.cast There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think the docstrings are wrong. The correct behaviour is that the assertion helpers take any two objects and raise an AssertionError if the objects are of the wrong type and not equal. If the objects passed were expected to be of a certain type I would expect to get a TypeError. >>> import pandas as pd
>>> import pandas.util.testing as tm
>>> pd.__version__
'0.26.0.dev0+984.g5ba71f883.dirty'
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), pd.Series())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
_check_isinstance(left, right, DataFrame)
File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), 'foo')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
_check_isinstance(left, right, DataFrame)
File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'str'> instead
>>>
>>> tm.assert_frame_equal(pd.Series(), pd.Series())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
_check_isinstance(left, right, DataFrame)
File "C:\Users\simon\pandas\pandas\util\testing.py", line 388, in _check_isinstance
err_msg.format(name=cls_name, exp_type=cls, act_type=type(left))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that if they are not typed as Any then this could trip up users who depend on this instance checking. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you suggesting a change to the implementation here? When I call an assert_* function I would not expect it to raise something other than an assertion error. So I wouldn't expect to get a TypeError. For example if I compare
Am I missing something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An alternate approach here may be to only type the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Otherwise we could narrow the types of both left and right and fall through to a failure message at the end that they aren’t of the same type, but maybe better inspected as a follow up; not sure we want that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this PR I think it makes sense to follow the existing implementation as described in the docstrings. If you call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about the following code
Is that valid code? IIUC mypy will complain about it, since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well for me that's not valid code, I would expect this to be caught by mypy since we're comparing a series to an array. But my understanding is that @simonjayhawkins thinks we should leave mypy out of this and allow any object in these functions. |
||
assert_series_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.DataFrame): | ||
elif isinstance(left, DataFrame): | ||
right = cast(DataFrame, right) | ||
assert_frame_equal(left, right, **kwargs) | ||
elif isinstance(left, IntervalArray): | ||
right = cast(IntervalArray, right) | ||
assert_interval_array_equal(left, right, **kwargs) | ||
elif isinstance(left, PeriodArray): | ||
right = cast(PeriodArray, right) | ||
assert_period_array_equal(left, right, **kwargs) | ||
elif isinstance(left, DatetimeArray): | ||
right = cast(DatetimeArray, right) | ||
assert_datetime_array_equal(left, right, **kwargs) | ||
elif isinstance(left, TimedeltaArray): | ||
right = cast(TimedeltaArray, right) | ||
assert_timedelta_array_equal(left, right, **kwargs) | ||
elif isinstance(left, ExtensionArray): | ||
right = cast(ExtensionArray, right) | ||
assert_extension_array_equal(left, right, **kwargs) | ||
elif isinstance(left, np.ndarray): | ||
right = cast(np.ndarray, right) | ||
assert_numpy_array_equal(left, right, **kwargs) | ||
elif isinstance(left, str): | ||
assert kwargs == {} | ||
return left == right | ||
assert left == right | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
raise NotImplementedError(type(left)) | ||
|
||
|
@@ -1497,12 +1529,12 @@ def to_array(obj): | |
|
||
|
||
def assert_sp_array_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_kind=True, | ||
check_fill_value=True, | ||
consolidate_block_indices=False, | ||
left: pd.SparseArray, | ||
right: pd.SparseArray, | ||
check_dtype: bool = True, | ||
check_kind: bool = True, | ||
check_fill_value: bool = True, | ||
consolidate_block_indices: bool = False, | ||
): | ||
"""Check that the left and right SparseArray are equal. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AnyArrayLike resolves to Any.
If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.