|
40 | 40 |
|
41 | 41 | def rands(n):
|
42 | 42 | choices = string.ascii_letters + string.digits
|
43 |
| - return ''.join([random.choice(choices) for _ in xrange(n)]) |
| 43 | + return ''.join(random.choice(choices) for _ in xrange(n)) |
44 | 44 |
|
45 | 45 |
|
46 | 46 | def randu(n):
|
@@ -749,13 +749,46 @@ def stdin_encoding(encoding=None):
|
749 | 749 | sys.stdin = _stdin
|
750 | 750 |
|
751 | 751 |
|
752 |
| -def assert_warns(warning, f, *args, **kwargs): |
| 752 | +@contextmanager |
| 753 | +def assert_produces_warning(expected_warning=Warning, filter_level="always"): |
753 | 754 | """
|
754 |
| - From: http://stackoverflow.com/questions/3892218/how-to-test-with-pythons-unittest-that-a-warning-has-been-thrown |
| 755 | + Context manager for running code that expects to raise (or not raise) |
| 756 | + warnings. Checks that code raises the expected warning and only the |
| 757 | + expected warning. Pass ``False`` or ``None`` to check that it does *not* |
| 758 | + raise a warning. Defaults to ``exception.Warning``, baseclass of all |
| 759 | + Warnings. (basically a wrapper around ``warnings.catch_warnings``). |
| 760 | +
|
| 761 | + >>> import warnings |
| 762 | + >>> with assert_produces_warning(): |
| 763 | + ... warnings.warn(UserWarning()) |
| 764 | + ... |
| 765 | + >>> with assert_produces_warning(False): |
| 766 | + ... warnings.warn(RuntimeWarning()) |
| 767 | + ... |
| 768 | + Traceback (most recent call last): |
| 769 | + ... |
| 770 | + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. |
| 771 | + >>> with assert_produces_warning(UserWarning): |
| 772 | + ... warnings.warn(RuntimeWarning()) |
| 773 | + Traceback (most recent call last): |
| 774 | + ... |
| 775 | + AssertionError: Did not see expected warning of class 'UserWarning'. |
| 776 | +
|
| 777 | + ..warn:: This is *not* thread-safe. |
755 | 778 | """
|
756 |
| - with warnings.catch_warnings(record=True) as warning_list: |
757 |
| - warnings.simplefilter('always') |
758 |
| - f(*args, **kwargs) |
759 |
| - msg = '{0!r} not raised'.format(warning) |
760 |
| - assert any(issubclass(item.category, warning) |
761 |
| - for item in warning_list), msg |
| 779 | + with warnings.catch_warnings(record=True) as w: |
| 780 | + saw_warning = False |
| 781 | + warnings.simplefilter(filter_level) |
| 782 | + yield w |
| 783 | + extra_warnings = [] |
| 784 | + for actual_warning in w: |
| 785 | + if (expected_warning and issubclass(actual_warning.category, |
| 786 | + expected_warning)): |
| 787 | + saw_warning = True |
| 788 | + else: |
| 789 | + extra_warnings.append(actual_warning.category.__name__) |
| 790 | + if expected_warning: |
| 791 | + assert saw_warning, ("Did not see expected warning of class %r." |
| 792 | + % expected_warning.__name__) |
| 793 | + assert not extra_warnings, ("Caused unexpected warning(s): %r." |
| 794 | + % extra_warnings) |
0 commit comments