diff --git a/pandas/_testing/_warnings.py b/pandas/_testing/_warnings.py index 9e89e09e418b3..1a8fe71ae3728 100644 --- a/pandas/_testing/_warnings.py +++ b/pandas/_testing/_warnings.py @@ -7,6 +7,7 @@ import re import sys from typing import ( + Literal, Sequence, Type, cast, @@ -17,7 +18,9 @@ @contextmanager def assert_produces_warning( expected_warning: type[Warning] | bool | None = Warning, - filter_level="always", + filter_level: Literal[ + "error", "ignore", "always", "default", "module", "once" + ] = "always", check_stacklevel: bool = True, raise_on_extra_warnings: bool = True, match: str | None = None, diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index ab42fcd92a3d9..b5e288690decb 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import cast +from typing import ( + Literal, + cast, +) import warnings import numpy as np @@ -10,6 +13,7 @@ no_default, ) from pandas._libs.missing import is_matching_na +from pandas._libs.sparse import SparseIndex import pandas._libs.testing as _testing from pandas.util._exceptions import find_stack_level @@ -61,7 +65,7 @@ def assert_almost_equal( left, right, - check_dtype: bool | str = "equiv", + check_dtype: bool | Literal["equiv"] = "equiv", check_less_precise: bool | int | NoDefault = no_default, rtol: float = 1.0e-5, atol: float = 1.0e-8, @@ -169,9 +173,8 @@ def assert_almost_equal( assert_class_equal(left, right, obj=obj) # if we have "equiv", this becomes True - check_dtype = bool(check_dtype) _testing.assert_almost_equal( - left, right, check_dtype=check_dtype, rtol=rtol, atol=atol, **kwargs + left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs ) @@ -686,7 +689,7 @@ def assert_numpy_array_equal( left, right, strict_nan=False, - check_dtype=True, + check_dtype: bool | Literal["equiv"] = True, err_msg=None, check_same=None, obj="numpy array", @@ -765,7 +768,7 @@ def _raise(left, right, err_msg): def assert_extension_array_equal( left, right, - check_dtype=True, + check_dtype: bool | Literal["equiv"] = True, index_values=None, check_less_precise=no_default, check_exact=False, @@ -858,7 +861,7 @@ def assert_extension_array_equal( _testing.assert_almost_equal( left_valid, right_valid, - check_dtype=check_dtype, + check_dtype=bool(check_dtype), rtol=rtol, atol=atol, obj="ExtensionArray", @@ -870,7 +873,7 @@ def assert_extension_array_equal( def assert_series_equal( left, right, - check_dtype=True, + check_dtype: bool | Literal["equiv"] = True, check_index_type="equiv", check_series_type=True, check_less_precise=no_default, @@ -1064,7 +1067,7 @@ def assert_series_equal( right._values, rtol=rtol, atol=atol, - check_dtype=check_dtype, + check_dtype=bool(check_dtype), obj=str(obj), index_values=np.asarray(left.index), ) @@ -1100,7 +1103,7 @@ def assert_series_equal( right._values, rtol=rtol, atol=atol, - check_dtype=check_dtype, + check_dtype=bool(check_dtype), obj=str(obj), index_values=np.asarray(left.index), ) @@ -1125,7 +1128,7 @@ def assert_series_equal( def assert_frame_equal( left, right, - check_dtype=True, + check_dtype: bool | Literal["equiv"] = True, check_index_type="equiv", check_column_type="equiv", check_frame_type=True, @@ -1403,8 +1406,8 @@ def assert_sp_array_equal(left, right): assert_numpy_array_equal(left.sp_values, right.sp_values) # SparseIndex comparison - assert isinstance(left.sp_index, pd._libs.sparse.SparseIndex) - assert isinstance(right.sp_index, pd._libs.sparse.SparseIndex) + assert isinstance(left.sp_index, SparseIndex) + assert isinstance(right.sp_index, SparseIndex) left_index = left.sp_index right_index = right.sp_index diff --git a/pandas/util/_test_decorators.py b/pandas/util/_test_decorators.py index d7eba6b8319fb..bbcf984e68b4b 100644 --- a/pandas/util/_test_decorators.py +++ b/pandas/util/_test_decorators.py @@ -35,6 +35,7 @@ def test_foo(): from pandas._config import get_option +from pandas._typing import F from pandas.compat import ( IS64, is_platform_windows, @@ -216,7 +217,7 @@ def skip_if_np_lt(ver_str: str, *args, reason: str | None = None): ) -def parametrize_fixture_doc(*args): +def parametrize_fixture_doc(*args) -> Callable[[F], F]: """ Intended for use as a decorator for parametrized fixture, this function will wrap the decorated function with a pytest