Skip to content

26302 add typing to assert star equal funcs #29364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 88 additions & 54 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shutil import rmtree
import string
import tempfile
from typing import Union, cast
from typing import Optional, Union, cast
import warnings
import zipfile

Expand Down Expand Up @@ -53,6 +53,7 @@
Series,
bdate_range,
)
from pandas._typing import AnyArrayLike
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AnyArrayLike resolves to Any.

If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.

from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
DatetimeArray,
Expand Down Expand Up @@ -806,8 +807,12 @@ def assert_is_sorted(seq):


def assert_categorical_equal(
left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
left: Categorical,
right: Categorical,
check_dtype: bool = True,
check_category_order: bool = True,
obj: str = "Categorical",
) -> None:
"""Test that Categoricals are equivalent.

Parameters
Expand Down Expand Up @@ -852,7 +857,12 @@ def assert_categorical_equal(
assert_attr_equal("ordered", left, right, obj=obj)


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
def assert_interval_array_equal(
left: IntervalArray,
right: IntervalArray,
exact: str = "equiv",
obj: str = "IntervalArray",
) -> None:
"""Test that two IntervalArrays are equivalent.

Parameters
Expand All @@ -878,7 +888,9 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
def assert_period_array_equal(
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray"
) -> None:
_check_isinstance(left, right, PeriodArray)

assert_numpy_array_equal(
Expand All @@ -887,7 +899,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
def assert_datetime_array_equal(
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)

Expand All @@ -896,7 +910,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
def assert_timedelta_array_equal(
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj))
Expand Down Expand Up @@ -931,13 +947,13 @@ def raise_assert_detail(obj, message, left, right, diff=None):


def assert_numpy_array_equal(
left,
right,
strict_nan=False,
check_dtype=True,
err_msg=None,
check_same=None,
obj="numpy array",
left: np.ndarray,
right: np.ndarray,
strict_nan: bool = False,
check_dtype: bool = True,
err_msg: Optional[str] = None,
check_same: Optional[str] = None,
obj: str = "numpy array",
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could annotate return of None here as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update this one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

""" Checks that 'np.ndarray' is equivalent

Expand Down Expand Up @@ -1067,18 +1083,18 @@ def assert_extension_array_equal(

# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_series_type=True,
check_less_precise=False,
check_names=True,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
obj="Series",
):
left: Series,
right: Series,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_series_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
obj: str = "Series",
) -> None:
"""
Check that left and right Series are equal.

Expand Down Expand Up @@ -1185,8 +1201,13 @@ def assert_series_equal(
right._internal_get_values(),
check_dtype=check_dtype,
)
elif is_interval_dtype(left) or is_interval_dtype(right):
assert_interval_array_equal(left.array, right.array)
elif is_interval_dtype(left) or is_interval_dtype(left):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now just elif is_interval_dtype(left) or should the second condition not have been changed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I think change was my mistake. Have removed it.

# must cast to interval dtype to keep mypy happy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the complaint on this? This changes the actual assertions being done I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't cast to IntervalArray I get these errors

pandas/util/testing.py:1211: error: Argument 1 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"
pandas/util/testing.py:1211: error: Argument 2 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the asserts though, they weren't needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned by @jreback should use cast from the typing module here - don't want to actually construct a new object via a call to IntervalArray

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can delete this comment

assert is_interval_dtype(right)
assert is_interval_dtype(left)
left_array = IntervalArray(left.array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we don't want to do this (cast is ok), but don't actually coerce with a constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha got it, wasn't aware of typing.cast. Have updated the code.

right_array = IntervalArray(right.array)
assert_interval_array_equal(left_array, right_array)
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
# .values is an ndarray, but ._values is the ExtensionArray.
# TODO: Use .array
Expand Down Expand Up @@ -1221,21 +1242,21 @@ def assert_series_equal(

# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_column_type="equiv",
check_frame_type=True,
check_less_precise=False,
check_names=True,
by_blocks=False,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
check_like=False,
obj="DataFrame",
):
left: DataFrame,
right: DataFrame,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_column_type: str = "equiv",
check_frame_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
obj: str = "DataFrame",
) -> None:
"""
Check that left and right DataFrame are equal.

Expand Down Expand Up @@ -1403,7 +1424,11 @@ def assert_frame_equal(
)


def assert_equal(left, right, **kwargs):
def assert_equal(
left: Union[DataFrame, AnyArrayLike],
right: Union[DataFrame, AnyArrayLike],
**kwargs
) -> None:
"""
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

Expand All @@ -1415,27 +1440,36 @@ def assert_equal(left, right, **kwargs):
"""
__tracebackhide__ = True

if isinstance(left, pd.Index):
if isinstance(left, Index):
assert isinstance(right, Index)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the message here? These changes also impact the tests so don't think we want these

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we don't want to actually change the semantics here, remove these

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I comment out assert isinstance(right, Index) then mypy complains about assert_index_equal(left, right, **kwargs):

pandas/util/testing.py:1443: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Index]"; expected "Index"
pandas/util/testing.py:1443: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Any]"; expected "Index"

But assuming left is an index, assert_index_equal(left, right) should raise if right isn't an index, so I don't think adding an assert isinstance(right, index) changes the semantics...? We will just catch the exception earlier based on the type.
I got the idea from @simonjayhawkins but I might have misunderstood.
Happy to do something different if there's a better way of handling this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I've just pushed another attempt (0c2f692) that uses the _check_isinstance(left, right, Index) command, which is always called by assert_index_equal. There's no change in semantics as far as I can tell. If that line doesn't raise then it's safe to case right to Index which keeps mypy happy.
If this is OK I can try and do something similar for the other types in assert_equal. Please let me know what you think.

assert_index_equal(left, right, **kwargs)
elif isinstance(left, pd.Series):
elif isinstance(left, Series):
assert isinstance(right, Series)
assert_series_equal(left, right, **kwargs)
elif isinstance(left, pd.DataFrame):
elif isinstance(left, DataFrame):
assert isinstance(right, DataFrame)
assert_frame_equal(left, right, **kwargs)
elif isinstance(left, IntervalArray):
assert isinstance(right, IntervalArray)
assert_interval_array_equal(left, right, **kwargs)
elif isinstance(left, PeriodArray):
assert isinstance(right, PeriodArray)
assert_period_array_equal(left, right, **kwargs)
elif isinstance(left, DatetimeArray):
assert isinstance(right, DatetimeArray)
assert_datetime_array_equal(left, right, **kwargs)
elif isinstance(left, TimedeltaArray):
assert isinstance(right, TimedeltaArray)
assert_timedelta_array_equal(left, right, **kwargs)
elif isinstance(left, ExtensionArray):
assert isinstance(right, ExtensionArray)
assert_extension_array_equal(left, right, **kwargs)
elif isinstance(left, np.ndarray):
assert isinstance(right, np.ndarray)
assert_numpy_array_equal(left, right, **kwargs)
elif isinstance(left, str):
assert kwargs == {}
return left == right
assert left == right
else:
raise NotImplementedError(type(left))

Expand Down Expand Up @@ -1497,12 +1531,12 @@ def to_array(obj):


def assert_sp_array_equal(
left,
right,
check_dtype=True,
check_kind=True,
check_fill_value=True,
consolidate_block_indices=False,
left: pd.SparseArray,
right: pd.SparseArray,
check_dtype: bool = True,
check_kind: bool = True,
check_fill_value: bool = True,
consolidate_block_indices: bool = False,
):
"""Check that the left and right SparseArray are equal.

Expand Down