diff --git a/pandas/_testing.py b/pandas/_testing.py index cf6272edc4c05..a4fdb390abf42 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -6,6 +6,7 @@ import gzip import operator import os +import re from shutil import rmtree import string import tempfile @@ -2546,10 +2547,11 @@ def wrapper(*args, **kwargs): @contextmanager def assert_produces_warning( - expected_warning=Warning, + expected_warning: Optional[Union[Type[Warning], bool]] = Warning, filter_level="always", - check_stacklevel=True, - raise_on_extra_warnings=True, + check_stacklevel: bool = True, + raise_on_extra_warnings: bool = True, + match: Optional[str] = None, ): """ Context manager for running code expected to either raise a specific @@ -2584,6 +2586,8 @@ class for all warnings. To check that no warning is returned, raise_on_extra_warnings : bool, default True Whether extra warnings not of the type `expected_warning` should cause the test to fail. + match : str, optional + Match warning message. Examples -------- @@ -2610,28 +2614,28 @@ class for all warnings. To check that no warning is returned, with warnings.catch_warnings(record=True) as w: saw_warning = False + matched_message = False + warnings.simplefilter(filter_level) yield w extra_warnings = [] for actual_warning in w: - if expected_warning and issubclass( - actual_warning.category, expected_warning - ): + if not expected_warning: + continue + + expected_warning = cast(Type[Warning], expected_warning) + if issubclass(actual_warning.category, expected_warning): saw_warning = True if check_stacklevel and issubclass( actual_warning.category, (FutureWarning, DeprecationWarning) ): - from inspect import getframeinfo, stack + _assert_raised_with_correct_stacklevel(actual_warning) + + if match is not None and re.search(match, str(actual_warning.message)): + matched_message = True - caller = getframeinfo(stack()[2][0]) - msg = ( - "Warning not set with correct stacklevel. " - f"File where warning is raised: {actual_warning.filename} != " - f"{caller.filename}. Warning message: {actual_warning.message}" - ) - assert actual_warning.filename == caller.filename, msg else: extra_warnings.append( ( @@ -2641,18 +2645,41 @@ class for all warnings. To check that no warning is returned, actual_warning.lineno, ) ) + if expected_warning: - msg = ( - f"Did not see expected warning of class " - f"{repr(expected_warning.__name__)}" - ) - assert saw_warning, msg + expected_warning = cast(Type[Warning], expected_warning) + if not saw_warning: + raise AssertionError( + f"Did not see expected warning of class " + f"{repr(expected_warning.__name__)}" + ) + + if match and not matched_message: + raise AssertionError( + f"Did not see warning {repr(expected_warning.__name__)} " + f"matching {match}" + ) + if raise_on_extra_warnings and extra_warnings: raise AssertionError( f"Caused unexpected warning(s): {repr(extra_warnings)}" ) +def _assert_raised_with_correct_stacklevel( + actual_warning: warnings.WarningMessage, +) -> None: + from inspect import getframeinfo, stack + + caller = getframeinfo(stack()[3][0]) + msg = ( + "Warning not set with correct stacklevel. " + f"File where warning is raised: {actual_warning.filename} != " + f"{caller.filename}. Warning message: {actual_warning.message}" + ) + assert actual_warning.filename == caller.filename, msg + + class RNGContext: """ Context manager to set the numpy random number generator speed. Returns diff --git a/pandas/tests/util/test_assert_produces_warning.py b/pandas/tests/util/test_assert_produces_warning.py index 87765c909938d..5f87824713175 100644 --- a/pandas/tests/util/test_assert_produces_warning.py +++ b/pandas/tests/util/test_assert_produces_warning.py @@ -1,10 +1,58 @@ +"""" +Test module for testing ``pandas._testing.assert_produces_warning``. +""" import warnings import pytest +from pandas.errors import DtypeWarning, PerformanceWarning + import pandas._testing as tm +@pytest.fixture( + params=[ + RuntimeWarning, + ResourceWarning, + UserWarning, + FutureWarning, + DeprecationWarning, + PerformanceWarning, + DtypeWarning, + ], +) +def category(request): + """ + Return unique warning. + + Useful for testing behavior of tm.assert_produces_warning with various categories. + """ + return request.param + + +@pytest.fixture( + params=[ + (RuntimeWarning, UserWarning), + (UserWarning, FutureWarning), + (FutureWarning, RuntimeWarning), + (DeprecationWarning, PerformanceWarning), + (PerformanceWarning, FutureWarning), + (DtypeWarning, DeprecationWarning), + (ResourceWarning, DeprecationWarning), + (FutureWarning, DeprecationWarning), + ], + ids=lambda x: type(x).__name__, +) +def pair_different_warnings(request): + """ + Return pair or different warnings. + + Useful for testing how several different warnings are handled + in tm.assert_produces_warning. + """ + return request.param + + def f(): warnings.warn("f1", FutureWarning) warnings.warn("f2", RuntimeWarning) @@ -20,3 +68,87 @@ def test_assert_produces_warning_honors_filter(): with tm.assert_produces_warning(RuntimeWarning, raise_on_extra_warnings=False): f() + + +@pytest.mark.parametrize( + "message, match", + [ + ("", None), + ("", ""), + ("Warning message", r".*"), + ("Warning message", "War"), + ("Warning message", r"[Ww]arning"), + ("Warning message", "age"), + ("Warning message", r"age$"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"^Mes.*\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}\s\S+"), + ("Message, which we do not match", None), + ], +) +def test_catch_warning_category_and_match(category, message, match): + with tm.assert_produces_warning(category, match=match): + warnings.warn(message, category) + + +@pytest.mark.parametrize( + "message, match", + [ + ("Warning message", "Not this message"), + ("Warning message", "warning"), + ("Warning message", r"\d+"), + ], +) +def test_fail_to_match(category, message, match): + msg = f"Did not see warning {repr(category.__name__)} matching" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(category, match=match): + warnings.warn(message, category) + + +def test_fail_to_catch_actual_warning(pair_different_warnings): + expected_category, actual_category = pair_different_warnings + match = "Did not see expected warning of class" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("warning message", actual_category) + + +def test_ignore_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + with tm.assert_produces_warning(expected_category, raise_on_extra_warnings=False): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning OK", extra_category) + + +def test_raise_on_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + match = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning NOT OK", extra_category) + + +def test_same_category_different_messages_first_match(): + category = UserWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", category) + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + + +def test_same_category_different_messages_last_match(): + category = DeprecationWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + warnings.warn("Match this", category) + + +def test_right_category_wrong_match_raises(pair_different_warnings): + target_category, other_category = pair_different_warnings + with pytest.raises(AssertionError, match="Did not see warning.*matching"): + with tm.assert_produces_warning(target_category, match=r"^Match this"): + warnings.warn("Do not match it", target_category) + warnings.warn("Match this", other_category)