diff --git a/pandas/util/_decorators.py b/pandas/util/_decorators.py index 2e91d76b0fe98..d10d3a1f71fe6 100644 --- a/pandas/util/_decorators.py +++ b/pandas/util/_decorators.py @@ -294,7 +294,6 @@ def update(self, *args, **kwargs) -> None: """ Update self.params with supplied args. """ - if isinstance(self.params, dict): self.params.update(*args, **kwargs) diff --git a/pandas/util/_depr_module.py b/pandas/util/_depr_module.py index 99dafdd760d26..5733663dd7ab3 100644 --- a/pandas/util/_depr_module.py +++ b/pandas/util/_depr_module.py @@ -4,11 +4,13 @@ """ import importlib +from typing import Iterable import warnings class _DeprecatedModule: - """ Class for mocking deprecated modules. + """ + Class for mocking deprecated modules. Parameters ---------- @@ -34,7 +36,7 @@ def __init__(self, deprmod, deprmodto=None, removals=None, moved=None): # For introspection purposes. self.self_dir = frozenset(dir(type(self))) - def __dir__(self): + def __dir__(self) -> Iterable[str]: deprmodule = self._import_deprmod() return dir(deprmodule) diff --git a/pandas/util/_doctools.py b/pandas/util/_doctools.py index 91972fed7a3bb..8fd4566d7763b 100644 --- a/pandas/util/_doctools.py +++ b/pandas/util/_doctools.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple + import numpy as np import pandas as pd @@ -9,24 +11,27 @@ class TablePlotter: Used in merging.rst """ - def __init__(self, cell_width=0.37, cell_height=0.25, font_size=7.5): + def __init__( + self, + cell_width: float = 0.37, + cell_height: float = 0.25, + font_size: float = 7.5, + ): self.cell_width = cell_width self.cell_height = cell_height self.font_size = font_size - def _shape(self, df): + def _shape(self, df: pd.DataFrame) -> Tuple[int, int]: """ Calculate table chape considering index levels. """ - row, col = df.shape return row + df.columns.nlevels, col + df.index.nlevels - def _get_cells(self, left, right, vertical): + def _get_cells(self, left, right, vertical) -> Tuple[int, int]: """ Calculate appropriate figure size based on left and right data. """ - if vertical: # calculate required number of cells vcells = max(sum(self._shape(l)[0] for l in left), self._shape(right)[0]) @@ -36,7 +41,7 @@ def _get_cells(self, left, right, vertical): hcells = sum([self._shape(l)[1] for l in left] + [self._shape(right)[1]]) return hcells, vcells - def plot(self, left, right, labels=None, vertical=True): + def plot(self, left, right, labels=None, vertical: bool = True): """ Plot left / right DataFrames in specified layout. @@ -45,7 +50,7 @@ def plot(self, left, right, labels=None, vertical=True): left : list of DataFrames before operation is applied right : DataFrame of operation result labels : list of str to be drawn as titles of left DataFrames - vertical : bool + vertical : bool, default True If True, use vertical layout. If False, use horizontal layout. """ import matplotlib.pyplot as plt @@ -96,7 +101,9 @@ def plot(self, left, right, labels=None, vertical=True): return fig def _conv(self, data): - """Convert each input to appropriate for table outplot""" + """ + Convert each input to appropriate for table outplot. + """ if isinstance(data, pd.Series): if data.name is None: data = data.to_frame(name="") @@ -127,7 +134,7 @@ def _insert_index(self, data): data.columns = col return data - def _make_table(self, ax, df, title, height=None): + def _make_table(self, ax, df, title: str, height: Optional[float] = None): if df is None: ax.set_visible(False) return diff --git a/pandas/util/_exceptions.py b/pandas/util/_exceptions.py index b8719154eb791..0723a37b1ba82 100644 --- a/pandas/util/_exceptions.py +++ b/pandas/util/_exceptions.py @@ -4,7 +4,9 @@ @contextlib.contextmanager def rewrite_exception(old_name: str, new_name: str): - """Rewrite the message of an exception.""" + """ + Rewrite the message of an exception. + """ try: yield except Exception as err: diff --git a/pandas/util/_print_versions.py b/pandas/util/_print_versions.py index b9f7e0c69f8b6..2801a2bf9c371 100644 --- a/pandas/util/_print_versions.py +++ b/pandas/util/_print_versions.py @@ -12,8 +12,9 @@ def get_sys_info() -> List[Tuple[str, Optional[Union[str, int]]]]: - "Returns system information as a list" - + """ + Returns system information as a list + """ blob: List[Tuple[str, Optional[Union[str, int]]]] = [] # get full commit hash @@ -123,7 +124,7 @@ def show_versions(as_json=False): print(tpl.format(k=k, stat=stat)) -def main(): +def main() -> int: from optparse import OptionParser parser = OptionParser() diff --git a/pandas/util/_test_decorators.py b/pandas/util/_test_decorators.py index 0e3ea25bf6fdb..7e14ed27d5bd4 100644 --- a/pandas/util/_test_decorators.py +++ b/pandas/util/_test_decorators.py @@ -37,7 +37,7 @@ def test_foo(): from pandas.core.computation.expressions import _NUMEXPR_INSTALLED, _USE_NUMEXPR -def safe_import(mod_name, min_version=None): +def safe_import(mod_name: str, min_version: Optional[str] = None): """ Parameters: ----------- @@ -110,7 +110,7 @@ def _skip_if_not_us_locale(): return True -def _skip_if_no_scipy(): +def _skip_if_no_scipy() -> bool: return not ( safe_import("scipy.stats") and safe_import("scipy.sparse") @@ -195,7 +195,9 @@ def skip_if_no(package: str, min_version: Optional[str] = None) -> Callable: ) -def skip_if_np_lt(ver_str, reason=None, *args, **kwds): +def skip_if_np_lt( + ver_str: str, reason: Optional[str] = None, *args, **kwds +) -> Callable: if reason is None: reason = f"NumPy {ver_str} or greater required" return pytest.mark.skipif( @@ -211,14 +213,14 @@ def parametrize_fixture_doc(*args): initial fixture docstring by replacing placeholders {0}, {1} etc with parameters passed as arguments. - Parameters: + Parameters ---------- - args: iterable - Positional arguments for docstring. + args: iterable + Positional arguments for docstring. - Returns: + Returns ------- - documented_fixture: function + function The decorated function wrapped within a pytest ``parametrize_fixture_doc`` mark """ @@ -230,7 +232,7 @@ def documented_fixture(fixture): return documented_fixture -def check_file_leaks(func): +def check_file_leaks(func) -> Callable: """ Decorate a test function tot check that we are not leaking file descriptors. """ diff --git a/pandas/util/_tester.py b/pandas/util/_tester.py index 6a3943cab692e..b299f3790ab22 100644 --- a/pandas/util/_tester.py +++ b/pandas/util/_tester.py @@ -1,5 +1,5 @@ """ -Entrypoint for testing from the top-level namespace +Entrypoint for testing from the top-level namespace. """ import os import sys @@ -22,7 +22,8 @@ def test(extra_args=None): extra_args = [extra_args] cmd = extra_args cmd += [PKG] - print(f"running: pytest {' '.join(cmd)}") + joined = " ".join(cmd) + print(f"running: pytest {joined}") sys.exit(pytest.main(cmd)) diff --git a/pandas/util/_validators.py b/pandas/util/_validators.py index 8b675a6b688fe..547fe748ae941 100644 --- a/pandas/util/_validators.py +++ b/pandas/util/_validators.py @@ -15,7 +15,6 @@ def _check_arg_length(fname, args, max_fname_arg_count, compat_args): Checks whether 'args' has length of at most 'compat_args'. Raises a TypeError if that is not the case, similar to in Python when a function is called with too many arguments. - """ if max_fname_arg_count < 0: raise ValueError("'max_fname_arg_count' must be non-negative") @@ -38,7 +37,6 @@ def _check_for_default_values(fname, arg_val_dict, compat_args): Note that this function is to be called only when it has been checked that arg_val_dict.keys() is a subset of compat_args - """ for key in arg_val_dict: # try checking equality directly with '=' operator, @@ -65,11 +63,8 @@ def _check_for_default_values(fname, arg_val_dict, compat_args): if not match: raise ValueError( - ( - f"the '{key}' parameter is not " - "supported in the pandas " - f"implementation of {fname}()" - ) + f"the '{key}' parameter is not supported in " + f"the pandas implementation of {fname}()" ) @@ -79,19 +74,18 @@ def validate_args(fname, args, max_fname_arg_count, compat_args): has at most `len(compat_args)` arguments and whether or not all of these elements in `args` are set to their default values. - fname: str + Parameters + ---------- + fname : str The name of the function being passed the `*args` parameter - - args: tuple + args : tuple The `*args` parameter passed into a function - - max_fname_arg_count: int + max_fname_arg_count : int The maximum number of arguments that the function `fname` can accept, excluding those in `args`. Used for displaying appropriate error messages. Must be non-negative. - - compat_args: OrderedDict - A ordered dictionary of keys and their associated default values. + compat_args : Dict + An ordered dictionary of keys and their associated default values. In order to accommodate buggy behaviour in some versions of `numpy`, where a signature displayed keyword arguments but then passed those arguments **positionally** internally when calling downstream @@ -101,10 +95,11 @@ def validate_args(fname, args, max_fname_arg_count, compat_args): Raises ------ - TypeError if `args` contains more values than there are `compat_args` - ValueError if `args` contains values that do not correspond to those - of the default values specified in `compat_args` - + TypeError + If `args` contains more values than there are `compat_args` + ValueError + If `args` contains values that do not correspond to those + of the default values specified in `compat_args` """ _check_arg_length(fname, args, max_fname_arg_count, compat_args) @@ -119,7 +114,6 @@ def _check_for_invalid_keys(fname, kwargs, compat_args): """ Checks whether 'kwargs' contains any keys that are not in 'compat_args' and raises a TypeError if there is one. - """ # set(dict) --> set of the dictionary's keys diff = set(kwargs) - set(compat_args) @@ -139,12 +133,10 @@ def validate_kwargs(fname, kwargs, compat_args): Parameters ---------- - fname: str + fname : str The name of the function being passed the `**kwargs` parameter - - kwargs: dict + kwargs : dict The `**kwargs` parameter passed into `fname` - compat_args: dict A dictionary of keys that `kwargs` is allowed to have and their associated default values @@ -154,7 +146,6 @@ def validate_kwargs(fname, kwargs, compat_args): TypeError if `kwargs` contains keys not in `compat_args` ValueError if `kwargs` contains keys in `compat_args` that do not map to the default values specified in `compat_args` - """ kwds = kwargs.copy() _check_for_invalid_keys(fname, kwargs, compat_args) @@ -171,18 +162,14 @@ def validate_args_and_kwargs(fname, args, kwargs, max_fname_arg_count, compat_ar ---------- fname: str The name of the function being passed the `**kwargs` parameter - args: tuple The `*args` parameter passed into a function - kwargs: dict The `**kwargs` parameter passed into `fname` - max_fname_arg_count: int The minimum number of arguments that the function `fname` requires, excluding those in `args`. Used for displaying appropriate error messages. Must be non-negative. - compat_args: OrderedDict A ordered dictionary of keys that `kwargs` is allowed to have and their associated default values. Note that if there diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 6350b1075f4a0..c31cddc102afb 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -8,7 +8,7 @@ from shutil import rmtree import string import tempfile -from typing import Union, cast +from typing import List, Optional, Union, cast import warnings import zipfile @@ -22,6 +22,7 @@ ) import pandas._libs.testing as _testing +from pandas._typing import FrameOrSeries from pandas.compat import _get_lzma_file, _import_lzma from pandas.core.dtypes.common import ( @@ -97,11 +98,10 @@ def reset_display_options(): """ Reset the display options for printing and representing objects. """ - pd.reset_option("^display.", silent=True) -def round_trip_pickle(obj, path=None): +def round_trip_pickle(obj: FrameOrSeries, path: Optional[str] = None) -> FrameOrSeries: """ Pickle an object and then read it again. @@ -114,10 +114,9 @@ def round_trip_pickle(obj, path=None): Returns ------- - round_trip_pickled_object : pandas object + pandas object The original object that was pickled and then re-read. """ - if path is None: path = f"__{rands(10)}__.pickle" with ensure_clean(path) as path: @@ -125,7 +124,7 @@ def round_trip_pickle(obj, path=None): return pd.read_pickle(path) -def round_trip_pathlib(writer, reader, path=None): +def round_trip_pathlib(writer, reader, path: Optional[str] = None): """ Write an object to file specified by a pathlib.Path and read it back @@ -140,10 +139,9 @@ def round_trip_pathlib(writer, reader, path=None): Returns ------- - round_trip_object : pandas object + pandas object The original object that was serialized and then re-read. """ - import pytest Path = pytest.importorskip("pathlib").Path @@ -155,9 +153,9 @@ def round_trip_pathlib(writer, reader, path=None): return obj -def round_trip_localpath(writer, reader, path=None): +def round_trip_localpath(writer, reader, path: Optional[str] = None): """ - Write an object to file specified by a py.path LocalPath and read it back + Write an object to file specified by a py.path LocalPath and read it back. Parameters ---------- @@ -170,7 +168,7 @@ def round_trip_localpath(writer, reader, path=None): Returns ------- - round_trip_object : pandas object + pandas object The original object that was serialized and then re-read. """ import pytest @@ -187,21 +185,20 @@ def round_trip_localpath(writer, reader, path=None): @contextmanager def decompress_file(path, compression): """ - Open a compressed file and return a file object + Open a compressed file and return a file object. Parameters ---------- path : str - The path where the file is read from + The path where the file is read from. compression : {'gzip', 'bz2', 'zip', 'xz', None} Name of the decompression to use Returns ------- - f : file object + file object """ - if compression is None: f = open(path, "rb") elif compression == "gzip": @@ -247,7 +244,6 @@ def write_to_compressed(compression, path, data, dest="test"): ------ ValueError : An invalid compression value was passed in. """ - if compression == "zip": import zipfile @@ -279,7 +275,11 @@ def write_to_compressed(compression, path, data, dest="test"): def assert_almost_equal( - left, right, check_dtype="equiv", check_less_precise=False, **kwargs + left, + right, + check_dtype: Union[bool, str] = "equiv", + check_less_precise: Union[bool, int] = False, + **kwargs, ): """ Check that the left and right objects are approximately equal. @@ -306,7 +306,6 @@ def assert_almost_equal( compare the **ratio** of the second number to the first number and check whether it is equivalent to 1 within the specified precision. """ - if isinstance(left, pd.Index): assert_index_equal( left, @@ -389,13 +388,13 @@ def _check_isinstance(left, right, cls): ) -def assert_dict_equal(left, right, compare_keys=True): +def assert_dict_equal(left, right, compare_keys: bool = True): _check_isinstance(left, right, dict) _testing.assert_dict_equal(left, right, compare_keys=compare_keys) -def randbool(size=(), p=0.5): +def randbool(size=(), p: float = 0.5): return rand(*size) <= p @@ -407,7 +406,9 @@ def randbool(size=(), p=0.5): def rands_array(nchars, size, dtype="O"): - """Generate an array of byte strings.""" + """ + Generate an array of byte strings. + """ retval = ( np.random.choice(RANDS_CHARS, size=nchars * np.prod(size)) .view((np.str_, nchars)) @@ -420,7 +421,9 @@ def rands_array(nchars, size, dtype="O"): def randu_array(nchars, size, dtype="O"): - """Generate an array of unicode strings.""" + """ + Generate an array of unicode strings. + """ retval = ( np.random.choice(RANDU_CHARS, size=nchars * np.prod(size)) .view((np.unicode_, nchars)) @@ -468,7 +471,8 @@ def close(fignum=None): @contextmanager def ensure_clean(filename=None, return_filelike=False): - """Gets a temporary path and agrees to remove on close. + """ + Gets a temporary path and agrees to remove on close. Parameters ---------- @@ -553,8 +557,9 @@ def ensure_safe_environment_variables(): # Comparators -def equalContents(arr1, arr2): - """Checks if the set of unique elements of arr1 and arr2 are equivalent. +def equalContents(arr1, arr2) -> bool: + """ + Checks if the set of unique elements of arr1 and arr2 are equivalent. """ return frozenset(arr1) == frozenset(arr2) @@ -691,8 +696,10 @@ def _get_ilevel_values(index, level): assert_categorical_equal(left.values, right.values, obj=f"{obj} category") -def assert_class_equal(left, right, exact=True, obj="Input"): - """checks classes are equal.""" +def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"): + """ + Checks classes are equal. + """ __tracebackhide__ = True def repr_class(x): @@ -2641,8 +2648,9 @@ def _constructor(self): @contextmanager -def set_timezone(tz): - """Context manager for temporarily setting a timezone. +def set_timezone(tz: str): + """ + Context manager for temporarily setting a timezone. Parameters ---------- @@ -2685,7 +2693,8 @@ def setTZ(tz): def _make_skipna_wrapper(alternative, skipna_alternative=None): - """Create a function for calling on an array. + """ + Create a function for calling on an array. Parameters ---------- @@ -2697,7 +2706,7 @@ def _make_skipna_wrapper(alternative, skipna_alternative=None): Returns ------- - skipna_wrapper : function + function """ if skipna_alternative: @@ -2715,7 +2724,7 @@ def skipna_wrapper(x): return skipna_wrapper -def convert_rows_list_to_csv_str(rows_list): +def convert_rows_list_to_csv_str(rows_list: List[str]): """ Convert list of CSV rows to single CSV-formatted string for current OS. @@ -2723,13 +2732,13 @@ def convert_rows_list_to_csv_str(rows_list): Parameters ---------- - rows_list : list - The list of string. Each element represents the row of csv. + rows_list : List[str] + Each element represents the row of csv. Returns ------- - expected : string - Expected output of to_csv() in current OS + str + Expected output of to_csv() in current OS. """ sep = os.linesep expected = sep.join(rows_list) + sep