From 3f0bf4c3719e59fff0621c7e293152c8c4d184e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Wed, 27 Dec 2023 16:22:30 -0500 Subject: [PATCH 1/2] TYP: fix some annotations --- pandas/core/interchange/from_dataframe.py | 37 +++++++++++++++- pandas/io/parsers/python_parser.py | 51 ++++++++++++----------- pandas/plotting/_matplotlib/boxplot.py | 4 +- pandas/plotting/_matplotlib/hist.py | 4 +- pandas/plotting/_matplotlib/style.py | 40 +++++++++++++++++- pandas/plotting/_matplotlib/tools.py | 9 ++-- pyright_reportGeneralTypeIssues.json | 2 +- 7 files changed, 106 insertions(+), 41 deletions(-) diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index d45ae37890ba7..73f492c83c2ff 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -2,7 +2,10 @@ import ctypes import re -from typing import Any +from typing import ( + Any, + overload, +) import numpy as np @@ -459,12 +462,42 @@ def buffer_to_ndarray( return np.array([], dtype=ctypes_type) +@overload +def set_nulls( + data: np.ndarray, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> np.ndarray: + ... + + +@overload +def set_nulls( + data: pd.Series, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> pd.Series: + ... + + +@overload +def set_nulls( + data: np.ndarray | pd.Series, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> np.ndarray | pd.Series: + ... + + def set_nulls( data: np.ndarray | pd.Series, col: Column, validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, allow_modify_inplace: bool = True, -): +) -> np.ndarray | pd.Series: """ Set null values for the data according to the column null kind. diff --git a/pandas/io/parsers/python_parser.py b/pandas/io/parsers/python_parser.py index 79e7554a5744c..c1880eb815032 100644 --- a/pandas/io/parsers/python_parser.py +++ b/pandas/io/parsers/python_parser.py @@ -4,12 +4,6 @@ abc, defaultdict, ) -from collections.abc import ( - Hashable, - Iterator, - Mapping, - Sequence, -) import csv from io import StringIO import re @@ -50,15 +44,24 @@ ) if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Iterator, + Mapping, + Sequence, + ) + from pandas._typing import ( ArrayLike, ReadCsvBuffer, Scalar, + T, ) from pandas import ( Index, MultiIndex, + Series, ) # BOM character (byte order mark) @@ -77,7 +80,7 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None: """ super().__init__(kwds) - self.data: Iterator[str] | None = None + self.data: Iterator[list[str]] | list[list[Scalar]] = [] self.buf: list = [] self.pos = 0 self.line_pos = 0 @@ -116,10 +119,11 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None: # Set self.data to something that can read lines. if isinstance(f, list): - # read_excel: f is a list - self.data = cast(Iterator[str], f) + # read_excel: f is a nested list, can contain non-str + self.data = f else: assert hasattr(f, "readline") + # yields list of str self.data = self._make_reader(f) # Get columns in two steps: infer from data, then @@ -179,7 +183,7 @@ def num(self) -> re.Pattern: ) return re.compile(regex) - def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]): + def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> Iterator[list[str]]: sep = self.delimiter if sep is None or len(sep) == 1: @@ -246,7 +250,9 @@ def _read(): def read( self, rows: int | None = None ) -> tuple[ - Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike] + Index | None, + Sequence[Hashable] | MultiIndex, + Mapping[Hashable, ArrayLike | Series], ]: try: content = self._get_lines(rows) @@ -326,7 +332,9 @@ def _exclude_implicit_index( def get_chunk( self, size: int | None = None ) -> tuple[ - Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike] + Index | None, + Sequence[Hashable] | MultiIndex, + Mapping[Hashable, ArrayLike | Series], ]: if size is None: # error: "PythonParser" has no attribute "chunksize" @@ -689,7 +697,7 @@ def _check_for_bom(self, first_row: list[Scalar]) -> list[Scalar]: new_row_list: list[Scalar] = [new_row] return new_row_list + first_row[1:] - def _is_line_empty(self, line: list[Scalar]) -> bool: + def _is_line_empty(self, line: Sequence[Scalar]) -> bool: """ Check if a line is empty or not. @@ -730,8 +738,6 @@ def _next_line(self) -> list[Scalar]: else: while self.skipfunc(self.pos): self.pos += 1 - # assert for mypy, data is Iterator[str] or None, would error in next - assert self.data is not None next(self.data) while True: @@ -800,12 +806,10 @@ def _next_iter_line(self, row_num: int) -> list[Scalar] | None: The row number of the line being parsed. """ try: - # assert for mypy, data is Iterator[str] or None, would error in next - assert self.data is not None + assert not isinstance(self.data, list) line = next(self.data) - # for mypy - assert isinstance(line, list) - return line + # lie about list[str] vs list[Scalar] to minimize ignores + return line # type: ignore[return-value] except csv.Error as e: if self.on_bad_lines in ( self.BadLineHandleMethod.ERROR, @@ -855,7 +859,7 @@ def _check_comments(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: ret.append(rl) return ret - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]: """ Iterate through the lines and remove any that are either empty or contain only one whitespace value @@ -1121,9 +1125,6 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: row_ct = 0 offset = self.pos if self.pos is not None else 0 while row_ct < rows: - # assert for mypy, data is Iterator[str] or None, would - # error in next - assert self.data is not None new_row = next(self.data) if not self.skipfunc(offset + row_index): row_ct += 1 @@ -1338,7 +1339,7 @@ def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> FixedWidthReader: self.infer_nrows, ) - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]: """ Returns the list of lines without the empty ones. With fixed-width fields, empty lines become arrays of empty strings. diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index d2b76decaa75d..084452ec23719 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -371,8 +371,8 @@ def _get_colors(): # num_colors=3 is required as method maybe_color_bp takes the colors # in positions 0 and 2. # if colors not provided, use same defaults as DataFrame.plot.box - result = get_standard_colors(num_colors=3) - result = np.take(result, [0, 0, 2]) + result_list = get_standard_colors(num_colors=3) + result = np.take(result_list, [0, 0, 2]) result = np.append(result, "k") colors = kwds.pop("color", None) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index e610f1adb602c..898abc9b78e3f 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -457,10 +457,8 @@ def hist_series( ax.grid(grid) axes = np.array([ax]) - # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any, - # dtype[Any]]"; expected "Axes | Sequence[Axes]" set_ticks_props( - axes, # type: ignore[arg-type] + axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py index bf4e4be3bfd82..5778cb0ab5695 100644 --- a/pandas/plotting/_matplotlib/style.py +++ b/pandas/plotting/_matplotlib/style.py @@ -7,7 +7,9 @@ import itertools from typing import ( TYPE_CHECKING, + Sequence, cast, + overload, ) import warnings @@ -26,12 +28,46 @@ from matplotlib.colors import Colormap +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color], +) -> dict[str, Color]: + ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: Color | Sequence[Color] | None = ..., +) -> list[Color]: + ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color] | Color | Sequence[Color] | None = ..., +) -> dict[str, Color] | list[Color]: + ... + + def get_standard_colors( num_colors: int, colormap: Colormap | None = None, color_type: str = "default", - color: dict[str, Color] | Color | Collection[Color] | None = None, -): + *, + color: dict[str, Color] | Color | Sequence[Color] | None = None, +) -> dict[str, Color] | list[Color]: """ Get standard colors based on `colormap`, `color_type` or `color` inputs. diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index 898b5b25e7b01..89a8a7cf79719 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -19,10 +19,7 @@ ) if TYPE_CHECKING: - from collections.abc import ( - Iterable, - Sequence, - ) + from collections.abc import Iterable from matplotlib.axes import Axes from matplotlib.axis import Axis @@ -442,7 +439,7 @@ def handle_shared_axes( _remove_labels_from_axis(ax.yaxis) -def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray: +def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray: if not is_list_like(axes): return np.array([axes]) elif isinstance(axes, (np.ndarray, ABCIndex)): @@ -451,7 +448,7 @@ def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray: def set_ticks_props( - axes: Axes | Sequence[Axes], + axes: Axes | Iterable[Axes], xlabelsize: int | None = None, xrot=None, ylabelsize: int | None = None, diff --git a/pyright_reportGeneralTypeIssues.json b/pyright_reportGeneralTypeIssues.json index a38343d6198ae..da27906e041cf 100644 --- a/pyright_reportGeneralTypeIssues.json +++ b/pyright_reportGeneralTypeIssues.json @@ -99,11 +99,11 @@ "pandas/io/parsers/base_parser.py", "pandas/io/parsers/c_parser_wrapper.py", "pandas/io/pytables.py", - "pandas/io/sas/sas_xport.py", "pandas/io/sql.py", "pandas/io/stata.py", "pandas/plotting/_matplotlib/boxplot.py", "pandas/plotting/_matplotlib/core.py", + "pandas/plotting/_matplotlib/misc.py", "pandas/plotting/_matplotlib/timeseries.py", "pandas/plotting/_matplotlib/tools.py", "pandas/tseries/frequencies.py", From 15cb419819fe90e433f6966291d15183bfd7ca3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Wed, 27 Dec 2023 16:45:09 -0500 Subject: [PATCH 2/2] pyupgrade --- pandas/plotting/_matplotlib/style.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py index 5778cb0ab5695..45a077a6151cf 100644 --- a/pandas/plotting/_matplotlib/style.py +++ b/pandas/plotting/_matplotlib/style.py @@ -3,11 +3,11 @@ from collections.abc import ( Collection, Iterator, + Sequence, ) import itertools from typing import ( TYPE_CHECKING, - Sequence, cast, overload, )