Skip to content

Commit 024eda5

Browse files
twoertwein authored and cbpygit committed
TYP: Fix some PythonParser and Plotting types (pandas-dev#56643)
* TYP: fix some annotations
* pyupgrade
1 parent f556ea2 commit 024eda5

File tree

7 files changed, with 106 additions and 41 deletions

pandas/core/interchange/from_dataframe.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import ctypes
44
import re
5-
from typing import Any
5+
from typing import (
6+
Any,
7+
overload,
8+
)
69

710
import numpy as np
811

@@ -459,12 +462,42 @@ def buffer_to_ndarray(
459462
return np.array([], dtype=ctypes_type)
460463

461464

465+
@overload
466+
def set_nulls(
467+
data: np.ndarray,
468+
col: Column,
469+
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
470+
allow_modify_inplace: bool = ...,
471+
) -> np.ndarray:
472+
...
473+
474+
475+
@overload
476+
def set_nulls(
477+
data: pd.Series,
478+
col: Column,
479+
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
480+
allow_modify_inplace: bool = ...,
481+
) -> pd.Series:
482+
...
483+
484+
485+
@overload
486+
def set_nulls(
487+
data: np.ndarray | pd.Series,
488+
col: Column,
489+
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
490+
allow_modify_inplace: bool = ...,
491+
) -> np.ndarray | pd.Series:
492+
...
493+
494+
462495
def set_nulls(
463496
data: np.ndarray | pd.Series,
464497
col: Column,
465498
validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None,
466499
allow_modify_inplace: bool = True,
467-
):
500+
) -> np.ndarray | pd.Series:
468501
"""
469502
Set null values for the data according to the column null kind.
470503

pandas/io/parsers/python_parser.py

+26-25
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@
44
abc,
55
defaultdict,
66
)
7-
from collections.abc import (
8-
Hashable,
9-
Iterator,
10-
Mapping,
11-
Sequence,
12-
)
137
import csv
148
from io import StringIO
159
import re
@@ -50,15 +44,24 @@
5044
)
5145

5246
if TYPE_CHECKING:
47+
from collections.abc import (
48+
Hashable,
49+
Iterator,
50+
Mapping,
51+
Sequence,
52+
)
53+
5354
from pandas._typing import (
5455
ArrayLike,
5556
ReadCsvBuffer,
5657
Scalar,
58+
T,
5759
)
5860

5961
from pandas import (
6062
Index,
6163
MultiIndex,
64+
Series,
6265
)
6366

6467
# BOM character (byte order mark)
@@ -77,7 +80,7 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None:
7780
"""
7881
super().__init__(kwds)
7982

80-
self.data: Iterator[str] | None = None
83+
self.data: Iterator[list[str]] | list[list[Scalar]] = []
8184
self.buf: list = []
8285
self.pos = 0
8386
self.line_pos = 0
@@ -116,10 +119,11 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None:
116119

117120
# Set self.data to something that can read lines.
118121
if isinstance(f, list):
119-
# read_excel: f is a list
120-
self.data = cast(Iterator[str], f)
122+
# read_excel: f is a nested list, can contain non-str
123+
self.data = f
121124
else:
122125
assert hasattr(f, "readline")
126+
# yields list of str
123127
self.data = self._make_reader(f)
124128

125129
# Get columns in two steps: infer from data, then
@@ -179,7 +183,7 @@ def num(self) -> re.Pattern:
179183
)
180184
return re.compile(regex)
181185

182-
def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]):
186+
def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> Iterator[list[str]]:
183187
sep = self.delimiter
184188

185189
if sep is None or len(sep) == 1:
@@ -246,7 +250,9 @@ def _read():
246250
def read(
247251
self, rows: int | None = None
248252
) -> tuple[
249-
Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike]
253+
Index | None,
254+
Sequence[Hashable] | MultiIndex,
255+
Mapping[Hashable, ArrayLike | Series],
250256
]:
251257
try:
252258
content = self._get_lines(rows)
@@ -326,7 +332,9 @@ def _exclude_implicit_index(
326332
def get_chunk(
327333
self, size: int | None = None
328334
) -> tuple[
329-
Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike]
335+
Index | None,
336+
Sequence[Hashable] | MultiIndex,
337+
Mapping[Hashable, ArrayLike | Series],
330338
]:
331339
if size is None:
332340
# error: "PythonParser" has no attribute "chunksize"
@@ -689,7 +697,7 @@ def _check_for_bom(self, first_row: list[Scalar]) -> list[Scalar]:
689697
new_row_list: list[Scalar] = [new_row]
690698
return new_row_list + first_row[1:]
691699

692-
def _is_line_empty(self, line: list[Scalar]) -> bool:
700+
def _is_line_empty(self, line: Sequence[Scalar]) -> bool:
693701
"""
694702
Check if a line is empty or not.
695703
@@ -730,8 +738,6 @@ def _next_line(self) -> list[Scalar]:
730738
else:
731739
while self.skipfunc(self.pos):
732740
self.pos += 1
733-
# assert for mypy, data is Iterator[str] or None, would error in next
734-
assert self.data is not None
735741
next(self.data)
736742

737743
while True:
@@ -800,12 +806,10 @@ def _next_iter_line(self, row_num: int) -> list[Scalar] | None:
800806
The row number of the line being parsed.
801807
"""
802808
try:
803-
# assert for mypy, data is Iterator[str] or None, would error in next
804-
assert self.data is not None
809+
assert not isinstance(self.data, list)
805810
line = next(self.data)
806-
# for mypy
807-
assert isinstance(line, list)
808-
return line
811+
# lie about list[str] vs list[Scalar] to minimize ignores
812+
return line # type: ignore[return-value]
809813
except csv.Error as e:
810814
if self.on_bad_lines in (
811815
self.BadLineHandleMethod.ERROR,
@@ -855,7 +859,7 @@ def _check_comments(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
855859
ret.append(rl)
856860
return ret
857861

858-
def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
862+
def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]:
859863
"""
860864
Iterate through the lines and remove any that are
861865
either empty or contain only one whitespace value
@@ -1121,9 +1125,6 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]:
11211125
row_ct = 0
11221126
offset = self.pos if self.pos is not None else 0
11231127
while row_ct < rows:
1124-
# assert for mypy, data is Iterator[str] or None, would
1125-
# error in next
1126-
assert self.data is not None
11271128
new_row = next(self.data)
11281129
if not self.skipfunc(offset + row_index):
11291130
row_ct += 1
@@ -1338,7 +1339,7 @@ def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> FixedWidthReader:
13381339
self.infer_nrows,
13391340
)
13401341

1341-
def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]:
1342+
def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]:
13421343
"""
13431344
Returns the list of lines without the empty ones. With fixed-width
13441345
fields, empty lines become arrays of empty strings.

pandas/plotting/_matplotlib/boxplot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,8 @@ def _get_colors():
371371
# num_colors=3 is required as method maybe_color_bp takes the colors
372372
# in positions 0 and 2.
373373
# if colors not provided, use same defaults as DataFrame.plot.box
374-
result = get_standard_colors(num_colors=3)
375-
result = np.take(result, [0, 0, 2])
374+
result_list = get_standard_colors(num_colors=3)
375+
result = np.take(result_list, [0, 0, 2])
376376
result = np.append(result, "k")
377377

378378
colors = kwds.pop("color", None)

pandas/plotting/_matplotlib/hist.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,8 @@ def hist_series(
457457
ax.grid(grid)
458458
axes = np.array([ax])
459459

460-
# error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
461-
# dtype[Any]]"; expected "Axes | Sequence[Axes]"
462460
set_ticks_props(
463-
axes, # type: ignore[arg-type]
461+
axes,
464462
xlabelsize=xlabelsize,
465463
xrot=xrot,
466464
ylabelsize=ylabelsize,

pandas/plotting/_matplotlib/style.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from collections.abc import (
44
Collection,
55
Iterator,
6+
Sequence,
67
)
78
import itertools
89
from typing import (
910
TYPE_CHECKING,
1011
cast,
12+
overload,
1113
)
1214
import warnings
1315

@@ -26,12 +28,46 @@
2628
from matplotlib.colors import Colormap
2729

2830

31+
@overload
32+
def get_standard_colors(
33+
num_colors: int,
34+
colormap: Colormap | None = ...,
35+
color_type: str = ...,
36+
*,
37+
color: dict[str, Color],
38+
) -> dict[str, Color]:
39+
...
40+
41+
42+
@overload
43+
def get_standard_colors(
44+
num_colors: int,
45+
colormap: Colormap | None = ...,
46+
color_type: str = ...,
47+
*,
48+
color: Color | Sequence[Color] | None = ...,
49+
) -> list[Color]:
50+
...
51+
52+
53+
@overload
54+
def get_standard_colors(
55+
num_colors: int,
56+
colormap: Colormap | None = ...,
57+
color_type: str = ...,
58+
*,
59+
color: dict[str, Color] | Color | Sequence[Color] | None = ...,
60+
) -> dict[str, Color] | list[Color]:
61+
...
62+
63+
2964
def get_standard_colors(
3065
num_colors: int,
3166
colormap: Colormap | None = None,
3267
color_type: str = "default",
33-
color: dict[str, Color] | Color | Collection[Color] | None = None,
34-
):
68+
*,
69+
color: dict[str, Color] | Color | Sequence[Color] | None = None,
70+
) -> dict[str, Color] | list[Color]:
3571
"""
3672
Get standard colors based on `colormap`, `color_type` or `color` inputs.
3773

pandas/plotting/_matplotlib/tools.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
)
2020

2121
if TYPE_CHECKING:
22-
from collections.abc import (
23-
Iterable,
24-
Sequence,
25-
)
22+
from collections.abc import Iterable
2623

2724
from matplotlib.axes import Axes
2825
from matplotlib.axis import Axis
@@ -442,7 +439,7 @@ def handle_shared_axes(
442439
_remove_labels_from_axis(ax.yaxis)
443440

444441

445-
def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
442+
def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray:
446443
if not is_list_like(axes):
447444
return np.array([axes])
448445
elif isinstance(axes, (np.ndarray, ABCIndex)):
@@ -451,7 +448,7 @@ def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
451448

452449

453450
def set_ticks_props(
454-
axes: Axes | Sequence[Axes],
451+
axes: Axes | Iterable[Axes],
455452
xlabelsize: int | None = None,
456453
xrot=None,
457454
ylabelsize: int | None = None,

pyright_reportGeneralTypeIssues.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@
9999
"pandas/io/parsers/base_parser.py",
100100
"pandas/io/parsers/c_parser_wrapper.py",
101101
"pandas/io/pytables.py",
102-
"pandas/io/sas/sas_xport.py",
103102
"pandas/io/sql.py",
104103
"pandas/io/stata.py",
105104
"pandas/plotting/_matplotlib/boxplot.py",
106105
"pandas/plotting/_matplotlib/core.py",
106+
"pandas/plotting/_matplotlib/misc.py",
107107
"pandas/plotting/_matplotlib/timeseries.py",
108108
"pandas/plotting/_matplotlib/tools.py",
109109
"pandas/tseries/frequencies.py",

0 commit comments

Comments
 (0)