Skip to content

26302 add typing to assert star equal funcs #29364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 85 additions & 56 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shutil import rmtree
import string
import tempfile
from typing import Union, cast
from typing import Optional, Union, cast
import warnings
import zipfile

Expand Down Expand Up @@ -53,6 +53,7 @@
Series,
bdate_range,
)
from pandas._typing import AnyArrayLike
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AnyArrayLike resolves to Any.

If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.

from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
DatetimeArray,
Expand Down Expand Up @@ -806,8 +807,12 @@ def assert_is_sorted(seq):


def assert_categorical_equal(
left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
left: Categorical,
right: Categorical,
check_dtype: bool = True,
check_category_order: bool = True,
obj: str = "Categorical",
) -> None:
"""Test that Categoricals are equivalent.

Parameters
Expand Down Expand Up @@ -852,7 +857,12 @@ def assert_categorical_equal(
assert_attr_equal("ordered", left, right, obj=obj)


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
def assert_interval_array_equal(
left: IntervalArray,
right: IntervalArray,
exact: str = "equiv",
obj: str = "IntervalArray",
) -> None:
"""Test that two IntervalArrays are equivalent.

Parameters
Expand All @@ -867,8 +877,6 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
Specify object name being compared, internally used to show appropriate
assertion message
"""
_check_isinstance(left, right, IntervalArray)

assert_index_equal(
left.left, right.left, exact=exact, obj="{obj}.left".format(obj=obj)
)
Expand All @@ -878,7 +886,9 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
def assert_period_array_equal(
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray"
) -> None:
_check_isinstance(left, right, PeriodArray)

assert_numpy_array_equal(
Expand All @@ -887,7 +897,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"):
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
def assert_datetime_array_equal(
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)

Expand All @@ -896,7 +908,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
def assert_timedelta_array_equal(
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray"
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj))
Expand Down Expand Up @@ -931,14 +945,14 @@ def raise_assert_detail(obj, message, left, right, diff=None):


def assert_numpy_array_equal(
left,
right,
strict_nan=False,
check_dtype=True,
err_msg=None,
check_same=None,
obj="numpy array",
):
left: np.ndarray,
right: np.ndarray,
strict_nan: bool = False,
check_dtype: bool = True,
err_msg: Optional[str] = None,
check_same: Optional[str] = None,
obj: str = "numpy array",
) -> None:
""" Checks that 'np.ndarray' is equivalent

Parameters
Expand Down Expand Up @@ -1067,18 +1081,18 @@ def assert_extension_array_equal(

# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_series_type=True,
check_less_precise=False,
check_names=True,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
obj="Series",
):
left: Series,
right: Series,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_series_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
obj: str = "Series",
) -> None:
"""
Check that left and right Series are equal.

Expand Down Expand Up @@ -1186,7 +1200,9 @@ def assert_series_equal(
check_dtype=check_dtype,
)
elif is_interval_dtype(left) or is_interval_dtype(right):
assert_interval_array_equal(left.array, right.array)
left_array = cast(IntervalArray, left.array)
right_array = cast(IntervalArray, right.array)
assert_interval_array_equal(left_array, right_array)
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
# .values is an ndarray, but ._values is the ExtensionArray.
# TODO: Use .array
Expand Down Expand Up @@ -1221,21 +1237,21 @@ def assert_series_equal(

# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
left,
right,
check_dtype=True,
check_index_type="equiv",
check_column_type="equiv",
check_frame_type=True,
check_less_precise=False,
check_names=True,
by_blocks=False,
check_exact=False,
check_datetimelike_compat=False,
check_categorical=True,
check_like=False,
obj="DataFrame",
):
left: DataFrame,
right: DataFrame,
check_dtype: bool = True,
check_index_type: str = "equiv",
check_column_type: str = "equiv",
check_frame_type: bool = True,
check_less_precise: bool = False,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
obj: str = "DataFrame",
) -> None:
"""
Check that left and right DataFrame are equal.

Expand Down Expand Up @@ -1400,7 +1416,11 @@ def assert_frame_equal(
)


def assert_equal(left, right, **kwargs):
def assert_equal(
left: Union[DataFrame, AnyArrayLike],
right: Union[DataFrame, AnyArrayLike],
**kwargs,
) -> None:
"""
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

Expand All @@ -1412,27 +1432,36 @@ def assert_equal(left, right, **kwargs):
"""
__tracebackhide__ = True

if isinstance(left, pd.Index):
if isinstance(left, Index):
right = cast(Index, right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't need any of the casts here; isinstance will narrow the types appropriately

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I remove the cast I get the following error.

pandas/util/testing.py:1440: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Index]"; expected "Index"
pandas/util/testing.py:1440: error: Argument 2 to "assert_index_equal" has incompatible type "Union[DataFrame, Any]"; expected "Index"

assert_index_equal(left, right, **kwargs)
elif isinstance(left, pd.Series):
elif isinstance(left, Series):
right = cast(Series, right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these casts shouldn't be added here. casts are for when we know better than mypy. in these cases the second argument to assert_series_equal etc may be anything. since the is_instance checks are performed in assert_series_equal etc.

it may be better to type all the the assert functions as

def assert_index_equal(
    left: Any,
    right: Any,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions do expect certain types (as described in the docstring), so I would think we'd want to be more explicit here, even if we need to add a few casts. These casts aren't happening at runtime anyway, right? https://docs.python.org/3/library/typing.html#typing.cast

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I thing the docstrings are wrong. The correct behaviour is that the assertion helpers take any two objects and raise an AssertionError if the objects are of the wrong type and not equal.

If the objects passed were expected to be of a certain type I would expect to get a TypeError.

>>> import pandas as pd
>>> import pandas.util.testing as tm
>>> pd.__version__
'0.26.0.dev0+984.g5ba71f883.dirty'
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), pd.Series())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>>
>>> tm.assert_frame_equal(pd.DataFrame(), 'foo')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 392, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(right))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'str'> instead
>>>
>>> tm.assert_frame_equal(pd.Series(), pd.Series())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 1330, in assert_frame_equal
    _check_isinstance(left, right, DataFrame)
  File "C:\Users\simon\pandas\pandas\util\testing.py", line 388, in _check_isinstance
    err_msg.format(name=cls_name, exp_type=cls, act_type=type(left))
AssertionError: DataFrame Expected type <class 'pandas.core.frame.DataFrame'>, found <class 'pandas.core.series.Series'> instead
>>>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that if they are not typed as Any then this could trip up users who depend on this instance checking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting a change to the implementation here? When I call an assert_* function I would not expect it to raise something other than an assertion error. So I wouldn't expect to get a TypeError. For example if I compare 3 with a bool with unittest.TestCase.assertIsInstance I just get an AssertionError, as I expect.

In [6]: unittest.TestCase.assertIsInstance(unittest.TestCase(), 3, bool)                       
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-4b36bb9f21c2> in <module>
----> 1 unittest.TestCase.assertIsInstance(unittest.TestCase(), 3, bool)

~/apps/miniconda3/envs/pandas-dev/lib/python3.7/unittest/case.py in assertIsInstance(self, obj, cls, msg)
   1274         if not isinstance(obj, cls):
   1275             standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
-> 1276             self.fail(self._formatMessage(msg, standardMsg))
   1277 
   1278     def assertNotIsInstance(self, obj, cls, msg=None):

~/apps/miniconda3/envs/pandas-dev/lib/python3.7/unittest/case.py in fail(self, msg)
    691     def fail(self, msg=None):
    692         """Fail immediately, with the given message."""
--> 693         raise self.failureException(msg)
    694 
    695     def assertFalse(self, expr, msg=None):

AssertionError: 3 is not an instance of <class 'bool'>

Am I missing something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternate approach here may be to only type the left argument as required and leave right typed to Any. I guess that is exactly what we are doing, since we just do an isinstance(left, ...) check to determine the function to call

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise we could narrow the types of both left and right and fall through to a failure message at the end that they aren’t of the same type, but maybe better inspected as a follow up; not sure we want that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR I think it makes sense to follow the existing implementation as described in the docstrings. If you call assert_index_equal I expect to be comparing two indexes, and if I feed something else to the function I’d expect a failure.
What mypy does is catch that failure at compile time when that is possible. Surely this is a good thing?
Again my understanding is that the casts aren’t changing the runtime behaviour so the current PR is faithful to the existing implementation.
In terms of semantics it also matches that of similar assertEquals methods in unittest.TestCase.
Shall we postpone deeper changes and refactorings to a follow up PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the following code

def test_raises():
    a = pd.Series([1])
    b = np.array([1])
    with pytest.raises(AssertionError):
        tm.assert_series_equal(a, b)

Is that valid code? IIUC mypy will complain about it, since b isn't a Series.

Copy link
Contributor Author

@samuelsinayoko samuelsinayoko Jan 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well for me that's not valid code, I would expect this to be caught by mypy since we're comparing a series to an array. But my understanding is that @simonjayhawkins thinkgs we should leave mypy out of this and allow any object in these functions.
@TomAugspurger see also my response to @simonjayhawkins: #29364 (comment)

assert_series_equal(left, right, **kwargs)
elif isinstance(left, pd.DataFrame):
elif isinstance(left, DataFrame):
right = cast(DataFrame, right)
assert_frame_equal(left, right, **kwargs)
elif isinstance(left, IntervalArray):
right = cast(IntervalArray, right)
assert_interval_array_equal(left, right, **kwargs)
elif isinstance(left, PeriodArray):
right = cast(PeriodArray, right)
assert_period_array_equal(left, right, **kwargs)
elif isinstance(left, DatetimeArray):
right = cast(DatetimeArray, right)
assert_datetime_array_equal(left, right, **kwargs)
elif isinstance(left, TimedeltaArray):
right = cast(TimedeltaArray, right)
assert_timedelta_array_equal(left, right, **kwargs)
elif isinstance(left, ExtensionArray):
right = cast(ExtensionArray, right)
assert_extension_array_equal(left, right, **kwargs)
elif isinstance(left, np.ndarray):
right = cast(np.ndarray, right)
assert_numpy_array_equal(left, right, **kwargs)
elif isinstance(left, str):
assert kwargs == {}
return left == right
assert left == right
else:
raise NotImplementedError(type(left))

Expand Down Expand Up @@ -1494,12 +1523,12 @@ def to_array(obj):


def assert_sp_array_equal(
left,
right,
check_dtype=True,
check_kind=True,
check_fill_value=True,
consolidate_block_indices=False,
left: pd.SparseArray,
right: pd.SparseArray,
check_dtype: bool = True,
check_kind: bool = True,
check_fill_value: bool = True,
consolidate_block_indices: bool = False,
):
"""Check that the left and right SparseArray are equal.

Expand Down