Skip to content

TYP: Fix some PythonParser and Plotting types #56643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions pandas/core/interchange/from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import ctypes
import re
from typing import Any
from typing import (
Any,
overload,
)

import numpy as np

Expand Down Expand Up @@ -459,12 +462,42 @@ def buffer_to_ndarray(
return np.array([], dtype=ctypes_type)


@overload
def set_nulls(
data: np.ndarray,
col: Column,
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
allow_modify_inplace: bool = ...,
) -> np.ndarray:
...


@overload
def set_nulls(
data: pd.Series,
col: Column,
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
allow_modify_inplace: bool = ...,
) -> pd.Series:
...


@overload
def set_nulls(
data: np.ndarray | pd.Series,
col: Column,
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
allow_modify_inplace: bool = ...,
) -> np.ndarray | pd.Series:
...


def set_nulls(
data: np.ndarray | pd.Series,
col: Column,
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
allow_modify_inplace: bool = True,
):
) -> np.ndarray | pd.Series:
"""
Set null values for the data according to the column null kind.

Expand Down
51 changes: 26 additions & 25 deletions pandas/io/parsers/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@
abc,
defaultdict,
)
from collections.abc import (
Hashable,
Iterator,
Mapping,
Sequence,
)
import csv
from io import StringIO
import re
Expand Down Expand Up @@ -50,15 +44,24 @@
)

if TYPE_CHECKING:
from collections.abc import (
Hashable,
Iterator,
Mapping,
Sequence,
)

from pandas._typing import (
ArrayLike,
ReadCsvBuffer,
Scalar,
T,
)

from pandas import (
Index,
MultiIndex,
Series,
)

# BOM character (byte order mark)
Expand All @@ -77,7 +80,7 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None:
"""
super().__init__(kwds)

self.data: Iterator[str] | None = None
self.data: Iterator[list[str]] | list[list[Scalar]] = []
self.buf: list = []
self.pos = 0
self.line_pos = 0
Expand Down Expand Up @@ -116,10 +119,11 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None:

# Set self.data to something that can read lines.
if isinstance(f, list):
# read_excel: f is a list
self.data = cast(Iterator[str], f)
# read_excel: f is a nested list, can contain non-str
self.data = f
else:
assert hasattr(f, "readline")
# yields list of str
self.data = self._make_reader(f)

# Get columns in two steps: infer from data, then
Expand Down Expand Up @@ -179,7 +183,7 @@ def num(self) -> re.Pattern:
)
return re.compile(regex)

def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]):
def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> Iterator[list[str]]:
sep = self.delimiter

if sep is None or len(sep) == 1:
Expand Down Expand Up @@ -246,7 +250,9 @@ def _read():
def read(
self, rows: int | None = None
) -> tuple[
Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike]
Index | None,
Sequence[Hashable] | MultiIndex,
Mapping[Hashable, ArrayLike | Series],
]:
try:
content = self._get_lines(rows)
Expand Down Expand Up @@ -326,7 +332,9 @@ def _exclude_implicit_index(
def get_chunk(
self, size: int | None = None
) -> tuple[
Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike]
Index | None,
Sequence[Hashable] | MultiIndex,
Mapping[Hashable, ArrayLike | Series],
]:
if size is None:
# error: "PythonParser" has no attribute "chunksize"
Expand Down Expand Up @@ -689,7 +697,7 @@ def _check_for_bom(self, first_row: list[Scalar]) -> list[Scalar]:
new_row_list: list[Scalar] = [new_row]
return new_row_list + first_row[1:]

def _is_line_empty(self, line: list[Scalar]) -> bool:
def _is_line_empty(self, line: Sequence[Scalar]) -> bool:
"""
Check if a line is empty or not.

Expand Down Expand Up @@ -730,8 +738,6 @@ def _next_line(self) -> list[Scalar]:
else:
while self.skipfunc(self.pos):
self.pos += 1
# assert for mypy, data is Iterator[str] or None, would error in next
assert self.data is not None
next(self.data)

while True:
Expand Down Expand Up @@ -800,12 +806,10 @@ def _next_iter_line(self, row_num: int) -> list[Scalar] | None:
The row number of the line being parsed.
"""
try:
# assert for mypy, data is Iterator[str] or None, would error in next
assert self.data is not None
assert not isinstance(self.data, list)
line = next(self.data)
# for mypy
assert isinstance(line, list)
return line
# lie about list[str] vs list[Scalar] to minimize ignores
return line # type: ignore[return-value]
except csv.Error as e:
if self.on_bad_lines in (
self.BadLineHandleMethod.ERROR,
Expand Down Expand Up @@ -855,7 +859,7 @@ def _check_comments(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
ret.append(rl)
return ret

def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]:
"""
Iterate through the lines and remove any that are
either empty or contain only one whitespace value
Expand Down Expand Up @@ -1121,9 +1125,6 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]:
row_ct = 0
offset = self.pos if self.pos is not None else 0
while row_ct < rows:
# assert for mypy, data is Iterator[str] or None, would
# error in next
assert self.data is not None
new_row = next(self.data)
if not self.skipfunc(offset + row_index):
row_ct += 1
Expand Down Expand Up @@ -1338,7 +1339,7 @@ def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> FixedWidthReader:
self.infer_nrows,
)

def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]:
"""
Returns the list of lines without the empty ones. With fixed-width
fields, empty lines become arrays of empty strings.
Expand Down
4 changes: 2 additions & 2 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ def _get_colors():
# num_colors=3 is required as method maybe_color_bp takes the colors
# in positions 0 and 2.
# if colors not provided, use same defaults as DataFrame.plot.box
result = get_standard_colors(num_colors=3)
result = np.take(result, [0, 0, 2])
result_list = get_standard_colors(num_colors=3)
result = np.take(result_list, [0, 0, 2])
result = np.append(result, "k")

colors = kwds.pop("color", None)
Expand Down
4 changes: 1 addition & 3 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,8 @@ def hist_series(
ax.grid(grid)
axes = np.array([ax])

# error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
# dtype[Any]]"; expected "Axes | Sequence[Axes]"
set_ticks_props(
axes, # type: ignore[arg-type]
axes,
xlabelsize=xlabelsize,
xrot=xrot,
ylabelsize=ylabelsize,
Expand Down
40 changes: 38 additions & 2 deletions pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from collections.abc import (
Collection,
Iterator,
Sequence,
)
import itertools
from typing import (
TYPE_CHECKING,
cast,
overload,
)
import warnings

Expand All @@ -26,12 +28,46 @@
from matplotlib.colors import Colormap


@overload
def get_standard_colors(
num_colors: int,
colormap: Colormap | None = ...,
color_type: str = ...,
*,
color: dict[str, Color],
) -> dict[str, Color]:
...


@overload
def get_standard_colors(
num_colors: int,
colormap: Colormap | None = ...,
color_type: str = ...,
*,
color: Color | Sequence[Color] | None = ...,
) -> list[Color]:
...


@overload
def get_standard_colors(
num_colors: int,
colormap: Colormap | None = ...,
color_type: str = ...,
*,
color: dict[str, Color] | Color | Sequence[Color] | None = ...,
) -> dict[str, Color] | list[Color]:
...


def get_standard_colors(
num_colors: int,
colormap: Colormap | None = None,
color_type: str = "default",
color: dict[str, Color] | Color | Collection[Color] | None = None,
):
*,
color: dict[str, Color] | Color | Sequence[Color] | None = None,
) -> dict[str, Color] | list[Color]:
"""
Get standard colors based on `colormap`, `color_type` or `color` inputs.

Expand Down
9 changes: 3 additions & 6 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
)

if TYPE_CHECKING:
from collections.abc import (
Iterable,
Sequence,
)
from collections.abc import Iterable

from matplotlib.axes import Axes
from matplotlib.axis import Axis
Expand Down Expand Up @@ -442,7 +439,7 @@ def handle_shared_axes(
_remove_labels_from_axis(ax.yaxis)


def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray:
if not is_list_like(axes):
return np.array([axes])
elif isinstance(axes, (np.ndarray, ABCIndex)):
Expand All @@ -451,7 +448,7 @@ def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:


def set_ticks_props(
axes: Axes | Sequence[Axes],
axes: Axes | Iterable[Axes],
xlabelsize: int | None = None,
xrot=None,
ylabelsize: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion pyright_reportGeneralTypeIssues.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@
"pandas/io/parsers/base_parser.py",
"pandas/io/parsers/c_parser_wrapper.py",
"pandas/io/pytables.py",
"pandas/io/sas/sas_xport.py",
"pandas/io/sql.py",
"pandas/io/stata.py",
"pandas/plotting/_matplotlib/boxplot.py",
"pandas/plotting/_matplotlib/core.py",
"pandas/plotting/_matplotlib/misc.py",
"pandas/plotting/_matplotlib/timeseries.py",
"pandas/plotting/_matplotlib/tools.py",
"pandas/tseries/frequencies.py",
Expand Down