Skip to content

26302 add typing to assert star equal funcs #29364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 89 additions & 57 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shutil import rmtree
import string
import tempfile
from typing import Union, cast
from typing import Optional, Union, cast
import warnings
import zipfile

Expand Down Expand Up @@ -53,6 +53,7 @@
Series,
bdate_range,
)
from pandas._typing import AnyArrayLike
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AnyArrayLike resolves to Any.

If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.

from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
DatetimeArray,
Expand Down Expand Up @@ -806,8 +807,12 @@ def assert_is_sorted(seq):


def assert_categorical_equal(
left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
left: Categorical,
right: Categorical,
check_dtype: bool = True,
check_category_order: bool = True,
obj: str = "Categorical",
) -> None:
"""Test that Categoricals are equivalent.

Parameters
Expand Down Expand Up @@ -852,7 +857,12 @@ def assert_categorical_equal(
assert_attr_equal("ordered", left, right, obj=obj)


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
def assert_interval_array_equal(
left: IntervalArray,
right: IntervalArray,
exact: str = "equiv",
obj: str = "IntervalArray",
) -> None:
"""Test that two IntervalArrays are equivalent.

Parameters
Expand All @@ -867,18 +877,20 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
Specify object name being compared, internally used to show appropriate
assertion message
"""
_check_isinstance(left, right, IntervalArray)

assert_index_equal(
left.left, right.left, exact=exact, obj="{obj}.left".format(obj=obj)
)
assert_index_equal(
left.right, right.right, exact=exact, obj="{obj}.left".format(obj=obj)
)
left = cast(IntervalArray, left)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these casts required? Seems like they only subsequently get sent to assert_attr_equal which is untyped, so not sure what the failure would be (may also be overlooking)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these casts are not required here.

$ mypy pandas --warn-redundant-casts
pandas\util\testing.py:886: error: Redundant cast to "IntervalArray"
pandas\util\testing.py:887: error: Redundant cast to "IntervalArray"

right = cast(IntervalArray, right)
assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
def assert_period_array_equal(
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray"
) -> None:
_check_isinstance(left, right, PeriodArray)

assert_numpy_array_equal(
Expand All @@ -887,7 +899,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
def assert_datetime_array_equal(
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)

Expand All @@ -896,7 +910,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
def assert_timedelta_array_equal(
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj))
Expand Down Expand Up @@ -931,14 +947,14 @@ def raise_assert_detail(obj, message, left, right, diff=None):


def assert_numpy_array_equal(
left,
right,
strict_nan=False,
check_dtype=True,
err_msg=None,
check_same=None,
obj="numpy array",
):
left: np.ndarray,
right: np.ndarray,
strict_nan: bool = False,
check_dtype: bool = True,
err_msg: Optional[str] = None,
check_same: Optional[str] = None,
obj: str = "numpy array",
) -> None:
""" Checks that 'np.ndarray' is equivalent

Parameters
Expand Down Expand Up @@ -1067,18 +1083,18 @@ def assert_extension_array_equal(

# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_series_type=True,
check_less_precise=False,
check_names=True,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
obj="Series",
):
left: Series,
right: Series,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_series_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
obj: str = "Series",
) -> None:
"""
Check that left and right Series are equal.

Expand Down Expand Up @@ -1185,8 +1201,11 @@ def assert_series_equal(
right._internal_get_values(),
check_dtype=check_dtype,
)
elif is_interval_dtype(left) or is_interval_dtype(right):
assert_interval_array_equal(left.array, right.array)
elif is_interval_dtype(left) or is_interval_dtype(left):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now just elif is_interval_dtype(left) or should the second condition not have been changed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I think change was my mistake. Have removed it.

# must cast to interval dtype to keep mypy happy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the complaint on this? This changes the actual assertions being done I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't cast to IntervalArray I get these errors

pandas/util/testing.py:1211: error: Argument 1 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"
pandas/util/testing.py:1211: error: Argument 2 to "assert_interval_array_equal" has incompatible type "ExtensionArray"; expected "IntervalArray"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the asserts though, they weren't needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned by @jreback should use cast from the typing module here - don't want to actually construct a new object via a call to IntervalArray

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can delete this comment

left_array = cast(IntervalArray, left.array)
right_array = cast(IntervalArray, right.array)
assert_interval_array_equal(left_array, right_array)
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
# .values is an ndarray, but ._values is the ExtensionArray.
# TODO: Use .array
Expand Down Expand Up @@ -1221,21 +1240,21 @@ def assert_series_equal(

# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_column_type="equiv",
check_frame_type=True,
check_less_precise=False,
check_names=True,
by_blocks=False,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
check_like=False,
obj="DataFrame",
):
left: DataFrame,
right: DataFrame,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_column_type: str = "equiv",
check_frame_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
obj: str = "DataFrame",
) -> None:
"""
Check that left and right DataFrame are equal.

Expand Down Expand Up @@ -1403,7 +1422,11 @@ def assert_frame_equal(
)


def assert_equal(left, right, **kwargs):
def assert_equal(
left: Union[DataFrame, AnyArrayLike],
right: Union[DataFrame, AnyArrayLike],
**kwargs
) -> None:
"""
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

Expand All @@ -1415,27 +1438,36 @@ def assert_equal(left, right, **kwargs):
"""
__tracebackhide__ = True

if isinstance(left, pd.Index):
if isinstance(left, Index):
right = cast(Index, right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't need any of the casts here; isinstance will narrow the types appropriately

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I remove the cast I get the following error.

pandas/util/testing.py:1440: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Index]"; expected "Index"
pandas/util/testing.py:1440: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Any]"; expected "Index"

assert_index_equal(left, right, **kwargs)
elif isinstance(left, pd.Series):
elif isinstance(left, Series):
right = cast(Series, right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these casts shouldn't be added here. casts are for when we know better than mypy. in these cases the second argument to assert_series_equal etc may be anything. since the is_instance checks are performed in assert_series_equal etc.

it may be better to type all the the assert functions as

def assert_index_equal(
    left: Any,
    right: Any,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions do expect certain types (as described in the docstring), so I would think we'd want to be more explicit here, even if we need to add a few casts. These casts aren't happening at runtime anyway, right? https://docs.python.org/3/library/typing.html#typing.cast

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I thing the docstrings are wrong. The correct behaviour is that the assertion helpers take any two objects and raise an AssertionError if the objects are of the wrong type and not equal.

If the objects passed were expected to be of a certain type I would expect to get a TypeError.

>>> import pandas as pd
>>> import pandas.util.testing as tm
>>> pd.__version__
'0.26.0.dev0+984.g5ba71f883.dirty'
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), pd.Series())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), 'foo')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'str'> instead
>>>
>>> tm.assert_frame_equal(pd.Series(), pd.Series())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 388, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(left))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that if they are not typed as Any then this could trip up users who depend on this instance checking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting a change to the implementation here? When I call an assert_* function I would not expect it to raise something other than an assertion error. So I wouldn't expect to get a TypeError. For example if I compare 3 with a bool with unittest.TestCase.assertIsInstance I just get an AssertionError, as I expect.

In [6]: unittest.TestCase.assertIsInstance(unittest.TestCase(), 3, bool)                       
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-4b36bb9f21c2> in <module>
----> 1 unittest.TestCase.assertIsInstance(unittest.TestCase(), 3, bool)

~/apps/miniconda3/envs/pandas-dev/lib/python3.7/unittest/case.py in assertIsInstance(self, obj, cls, msg)
   1274         if not isinstance(obj, cls):
   1275             standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
-> 1276             self.fail(self._formatMessage(msg, standardMsg))
   1277 
   1278     def assertNotIsInstance(self, obj, cls, msg=None):

~/apps/miniconda3/envs/pandas-dev/lib/python3.7/unittest/case.py in fail(self, msg)
    691     def fail(self, msg=None):
    692         """Fail immediately, with the given message."""
--> 693         raise self.failureException(msg)
    694 
    695     def assertFalse(self, expr, msg=None):

AssertionError: 3 is not an instance of <class 'bool'>

Am I missing something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternate approach here may be to only type the left argument as required and leave right typed to Any. I guess that is exactly what we are doing, since we just do an isinstance(left, ...) check to determine the function to call

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise we could narrow the types of both left and right and fall through to a failure message at the end that they aren’t of the same type, but maybe better inspected as a follow up; not sure we want that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR I think it makes sense to follow the existing implementation as described in the docstrings. If you call assert_index_equal I expect to be comparing two indexes, and if I feed something else to the function I’d expect a failure.
What mypy does is catch that failure at compile time when that is possible. Surely this is a good thing?
Again my understanding is that the casts aren’t changing the runtime behaviour so the current PR is faithful to the existing implementation.
In terms of semantics it also matches that of similar assertEquals methods in unittest.TestCase.
Shall we postpone deeper changes and refactorings to a follow up PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the following code

def test_raises():
    a = pd.Series([1])
    b = np.array([1])
    with pytest.raises(AssertionError):
        tm.assert_series_equal(a, b)

Is that valid code? IIUC mypy will complain about it, since b isn't a Series.

Copy link
Contributor Author

@samuelsinayoko samuelsinayoko Jan 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well for me that's not valid code, I would expect this to be caught by mypy since we're comparing a series to an array. But my understanding is that @simonjayhawkins thinkgs we should leave mypy out of this and allow any object in these functions.
@TomAugspurger see also my response to @simonjayhawkins: #29364 (comment)

assert_series_equal(left, right, **kwargs)
elif isinstance(left, pd.DataFrame):
elif isinstance(left, DataFrame):
right = cast(DataFrame, right)
assert_frame_equal(left, right, **kwargs)
elif isinstance(left, IntervalArray):
right = cast(IntervalArray, right)
assert_interval_array_equal(left, right, **kwargs)
elif isinstance(left, PeriodArray):
right = cast(PeriodArray, right)
assert_period_array_equal(left, right, **kwargs)
elif isinstance(left, DatetimeArray):
right = cast(DatetimeArray, right)
assert_datetime_array_equal(left, right, **kwargs)
elif isinstance(left, TimedeltaArray):
right = cast(TimedeltaArray, right)
assert_timedelta_array_equal(left, right, **kwargs)
elif isinstance(left, ExtensionArray):
right = cast(ExtensionArray, right)
assert_extension_array_equal(left, right, **kwargs)
elif isinstance(left, np.ndarray):
right = cast(np.ndarray, right)
assert_numpy_array_equal(left, right, **kwargs)
elif isinstance(left, str):
assert kwargs == {}
return left == right
assert left == right
else:
raise NotImplementedError(type(left))

Expand Down Expand Up @@ -1497,12 +1529,12 @@ def to_array(obj):


def assert_sp_array_equal(
left,
right,
check_dtype=True,
check_kind=True,
check_fill_value=True,
consolidate_block_indices=False,
left: pd.SparseArray,
right: pd.SparseArray,
check_dtype: bool = True,
check_kind: bool = True,
check_fill_value: bool = True,
consolidate_block_indices: bool = False,
):
"""Check that the left and right SparseArray are equal.

Expand Down