Skip to content

Commit d734496

Browse files
authored
TYP: plotting (pandas-dev#55887)
* TYP: _iter_data * TYP: plotting * TYP: plotting * TYP: plotting * Improve check * TYP: plotting * lint fixup * mypy fixup * pyright fixup
1 parent 6755b81 commit d734496

File tree

6 files changed

+99
-65
lines changed

6 files changed

+99
-65
lines changed

pandas/plotting/_core.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@
3737

3838
from pandas._typing import IndexLabel
3939

40-
from pandas import DataFrame
40+
from pandas import (
41+
DataFrame,
42+
Series,
43+
)
44+
from pandas.core.groupby.generic import DataFrameGroupBy
4145

4246

4347
def hist_series(
44-
self,
48+
self: Series,
4549
by=None,
4650
ax=None,
4751
grid: bool = True,
@@ -512,7 +516,7 @@ def boxplot(
512516
@Substitution(data="", backend=_backend_doc)
513517
@Appender(_boxplot_doc)
514518
def boxplot_frame(
515-
self,
519+
self: DataFrame,
516520
column=None,
517521
by=None,
518522
ax=None,
@@ -542,7 +546,7 @@ def boxplot_frame(
542546

543547

544548
def boxplot_frame_groupby(
545-
grouped,
549+
grouped: DataFrameGroupBy,
546550
subplots: bool = True,
547551
column=None,
548552
fontsize: int | None = None,
@@ -843,11 +847,11 @@ class PlotAccessor(PandasObject):
843847
_kind_aliases = {"density": "kde"}
844848
_all_kinds = _common_kinds + _series_kinds + _dataframe_kinds
845849

846-
def __init__(self, data) -> None:
850+
def __init__(self, data: Series | DataFrame) -> None:
847851
self._parent = data
848852

849853
@staticmethod
850-
def _get_call_args(backend_name: str, data, args, kwargs):
854+
def _get_call_args(backend_name: str, data: Series | DataFrame, args, kwargs):
851855
"""
852856
This function makes calls to this accessor `__call__` method compatible
853857
with the previous `SeriesPlotMethods.__call__` and

pandas/plotting/_matplotlib/boxplot.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,18 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
9393
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
9494
@classmethod
9595
def _plot( # type: ignore[override]
96-
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
96+
cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds
9797
):
98+
ys: np.ndarray | list[np.ndarray]
9899
if y.ndim == 2:
99-
y = [remove_na_arraylike(v) for v in y]
100+
ys = [remove_na_arraylike(v) for v in y]
100101
# Boxplot fails with empty arrays, so need to add a NaN
101102
# if any cols are empty
102103
# GH 8181
103-
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
104+
ys = [v if v.size > 0 else np.array([np.nan]) for v in ys]
104105
else:
105-
y = remove_na_arraylike(y)
106-
bp = ax.boxplot(y, **kwds)
106+
ys = remove_na_arraylike(y)
107+
bp = ax.boxplot(ys, **kwds)
107108

108109
if return_type == "dict":
109110
return bp, bp
@@ -240,8 +241,7 @@ def _make_plot(self, fig: Figure) -> None:
240241
self.maybe_color_bp(bp)
241242
self._return_obj = ret
242243

243-
labels = [left for left, _ in self._iter_data()]
244-
labels = [pprint_thing(left) for left in labels]
244+
labels = [pprint_thing(left) for left in self.data.columns]
245245
if not self.use_index:
246246
labels = [pprint_thing(key) for key in range(len(labels))]
247247
_set_ticklabels(
@@ -251,7 +251,7 @@ def _make_plot(self, fig: Figure) -> None:
251251
def _make_legend(self) -> None:
252252
pass
253253

254-
def _post_plot_logic(self, ax, data) -> None:
254+
def _post_plot_logic(self, ax: Axes, data) -> None:
255255
# GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel
256256
if self.xlabel:
257257
ax.set_xlabel(pprint_thing(self.xlabel))

pandas/plotting/_matplotlib/core.py

+62-35
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import (
88
Hashable,
99
Iterable,
10+
Iterator,
1011
Sequence,
1112
)
1213
from typing import (
@@ -431,17 +432,15 @@ def _validate_color_args(self):
431432
)
432433

433434
@final
434-
def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
435-
if data is None:
436-
data = self.data
437-
if fillna is not None:
438-
data = data.fillna(fillna)
439-
435+
@staticmethod
436+
def _iter_data(
437+
data: DataFrame | dict[Hashable, Series | DataFrame]
438+
) -> Iterator[tuple[Hashable, np.ndarray]]:
440439
for col, values in data.items():
441-
if keep_index is True:
442-
yield col, values
443-
else:
444-
yield col, values.values
440+
# This was originally written to use values.values before EAs
441+
# were implemented; adding np.asarray(...) to keep consistent
442+
# typing.
443+
yield col, np.asarray(values.values)
445444

446445
@property
447446
def nseries(self) -> int:
@@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool:
480479
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0
481480

482481
@final
483-
def _maybe_right_yaxis(self, ax: Axes, axes_num: int):
482+
def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes:
484483
if not self.on_right(axes_num):
485484
# secondary axes may be passed via ax kw
486485
return self._get_ax_layer(ax)
@@ -656,11 +655,7 @@ def _compute_plot_data(self):
656655

657656
numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type)
658657

659-
try:
660-
is_empty = numeric_data.columns.empty
661-
except AttributeError:
662-
is_empty = not len(numeric_data)
663-
658+
is_empty = numeric_data.shape[-1] == 0
664659
# no non-numeric frames or series allowed
665660
if is_empty:
666661
raise TypeError("no numeric data to plot")
@@ -682,7 +677,7 @@ def _add_table(self) -> None:
682677
tools.table(ax, data)
683678

684679
@final
685-
def _post_plot_logic_common(self, ax, data):
680+
def _post_plot_logic_common(self, ax: Axes, data) -> None:
686681
"""Common post process for each axes"""
687682
if self.orientation == "vertical" or self.orientation is None:
688683
self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize)
@@ -701,7 +696,7 @@ def _post_plot_logic_common(self, ax, data):
701696
raise ValueError
702697

703698
@abstractmethod
704-
def _post_plot_logic(self, ax, data) -> None:
699+
def _post_plot_logic(self, ax: Axes, data) -> None:
705700
"""Post process for each axes. Overridden in child classes"""
706701

707702
@final
@@ -1056,7 +1051,7 @@ def _get_colors(
10561051
)
10571052

10581053
@final
1059-
def _parse_errorbars(self, label, err):
1054+
def _parse_errorbars(self, label: str, err):
10601055
"""
10611056
Look for error keyword arguments and return the actual errorbar data
10621057
or return the error DataFrame/dict
@@ -1137,7 +1132,10 @@ def match_labels(data, e):
11371132
err = np.tile(err, (self.nseries, 1))
11381133

11391134
elif is_number(err):
1140-
err = np.tile([err], (self.nseries, len(self.data)))
1135+
err = np.tile(
1136+
[err], # pyright: ignore[reportGeneralTypeIssues]
1137+
(self.nseries, len(self.data)),
1138+
)
11411139

11421140
else:
11431141
msg = f"No valid {label} detected"
@@ -1418,14 +1416,14 @@ def _make_plot(self, fig: Figure) -> None:
14181416

14191417
x = data.index # dummy, not used
14201418
plotf = self._ts_plot
1421-
it = self._iter_data(data=data, keep_index=True)
1419+
it = data.items()
14221420
else:
14231421
x = self._get_xticks(convert_period=True)
14241422
# error: Incompatible types in assignment (expression has type
14251423
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
14261424
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
14271425
plotf = self._plot # type: ignore[assignment]
1428-
it = self._iter_data()
1426+
it = self._iter_data(data=self.data)
14291427

14301428
stacking_id = self._get_stacking_id()
14311429
is_errorbar = com.any_not_none(*self.errors.values())
@@ -1434,7 +1432,12 @@ def _make_plot(self, fig: Figure) -> None:
14341432
for i, (label, y) in enumerate(it):
14351433
ax = self._get_ax(i)
14361434
kwds = self.kwds.copy()
1437-
style, kwds = self._apply_style_colors(colors, kwds, i, label)
1435+
style, kwds = self._apply_style_colors(
1436+
colors,
1437+
kwds,
1438+
i,
1439+
label, # pyright: ignore[reportGeneralTypeIssues]
1440+
)
14381441

14391442
errors = self._get_errorbars(label=label, index=i)
14401443
kwds = dict(kwds, **errors)
@@ -1446,7 +1449,7 @@ def _make_plot(self, fig: Figure) -> None:
14461449
newlines = plotf(
14471450
ax,
14481451
x,
1449-
y,
1452+
y, # pyright: ignore[reportGeneralTypeIssues]
14501453
style=style,
14511454
column_num=i,
14521455
stacking_id=stacking_id,
@@ -1465,7 +1468,14 @@ def _make_plot(self, fig: Figure) -> None:
14651468
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
14661469
@classmethod
14671470
def _plot( # type: ignore[override]
1468-
cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds
1471+
cls,
1472+
ax: Axes,
1473+
x,
1474+
y: np.ndarray,
1475+
style=None,
1476+
column_num=None,
1477+
stacking_id=None,
1478+
**kwds,
14691479
):
14701480
# column_num is used to get the target column from plotf in line and
14711481
# area plots
@@ -1492,7 +1502,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
14921502
decorate_axes(ax.right_ax, freq, kwds)
14931503
ax._plot_data.append((data, self._kind, kwds))
14941504

1495-
lines = self._plot(ax, data.index, data.values, style=style, **kwds)
1505+
lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds)
14961506
# set date formatter, locators and rescale limits
14971507
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
14981508
# expected "DatetimeIndex | PeriodIndex"
@@ -1520,7 +1530,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
15201530

15211531
@final
15221532
@classmethod
1523-
def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
1533+
def _get_stacked_values(
1534+
cls, ax: Axes, stacking_id: int | None, values: np.ndarray, label
1535+
) -> np.ndarray:
15241536
if stacking_id is None:
15251537
return values
15261538
if not hasattr(ax, "_stacker_pos_prior"):
@@ -1540,7 +1552,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
15401552

15411553
@final
15421554
@classmethod
1543-
def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
1555+
def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
15441556
if stacking_id is None:
15451557
return
15461558
if (values >= 0).all():
@@ -1618,7 +1630,7 @@ def _plot( # type: ignore[override]
16181630
cls,
16191631
ax: Axes,
16201632
x,
1621-
y,
1633+
y: np.ndarray,
16221634
style=None,
16231635
column_num=None,
16241636
stacking_id=None,
@@ -1744,7 +1756,7 @@ def _plot( # type: ignore[override]
17441756
cls,
17451757
ax: Axes,
17461758
x,
1747-
y,
1759+
y: np.ndarray,
17481760
w,
17491761
start: int | npt.NDArray[np.intp] = 0,
17501762
log: bool = False,
@@ -1763,7 +1775,8 @@ def _make_plot(self, fig: Figure) -> None:
17631775
pos_prior = neg_prior = np.zeros(len(self.data))
17641776
K = self.nseries
17651777

1766-
for i, (label, y) in enumerate(self._iter_data(fillna=0)):
1778+
data = self.data.fillna(0)
1779+
for i, (label, y) in enumerate(self._iter_data(data=data)):
17671780
ax = self._get_ax(i)
17681781
kwds = self.kwds.copy()
17691782
if self._is_series:
@@ -1842,7 +1855,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
18421855

18431856
self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge)
18441857

1845-
def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
1858+
def _decorate_ticks(
1859+
self,
1860+
ax: Axes,
1861+
name: str | None,
1862+
ticklabels: list[str],
1863+
start_edge: float,
1864+
end_edge: float,
1865+
) -> None:
18461866
ax.set_xlim((start_edge, end_edge))
18471867

18481868
if self.xticks is not None:
@@ -1876,7 +1896,7 @@ def _plot( # type: ignore[override]
18761896
cls,
18771897
ax: Axes,
18781898
x,
1879-
y,
1899+
y: np.ndarray,
18801900
w,
18811901
start: int | npt.NDArray[np.intp] = 0,
18821902
log: bool = False,
@@ -1887,7 +1907,14 @@ def _plot( # type: ignore[override]
18871907
def _get_custom_index_name(self):
18881908
return self.ylabel
18891909

1890-
def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
1910+
def _decorate_ticks(
1911+
self,
1912+
ax: Axes,
1913+
name: str | None,
1914+
ticklabels: list[str],
1915+
start_edge: float,
1916+
end_edge: float,
1917+
) -> None:
18911918
# horizontal bars
18921919
ax.set_ylim((start_edge, end_edge))
18931920
ax.set_yticks(self.tick_pos)
@@ -1921,7 +1948,7 @@ def _make_plot(self, fig: Figure) -> None:
19211948
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
19221949
self.kwds.setdefault("colors", colors)
19231950

1924-
for i, (label, y) in enumerate(self._iter_data()):
1951+
for i, (label, y) in enumerate(self._iter_data(data=self.data)):
19251952
ax = self._get_ax(i)
19261953
if label is not None:
19271954
label = pprint_thing(label)

pandas/plotting/_matplotlib/groupby.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
from pandas.plotting._matplotlib.misc import unpack_single_str_list
1717

1818
if TYPE_CHECKING:
19+
from collections.abc import Hashable
20+
1921
from pandas._typing import IndexLabel
2022

2123

2224
def create_iter_data_given_by(
2325
data: DataFrame, kind: str = "hist"
24-
) -> dict[str, DataFrame | Series]:
26+
) -> dict[Hashable, DataFrame | Series]:
2527
"""
2628
Create data for iteration given `by` is assigned or not, and it is only
2729
used in both hist and boxplot.
@@ -126,9 +128,7 @@ def reconstruct_data_with_by(
126128
return data
127129

128130

129-
def reformat_hist_y_given_by(
130-
y: Series | np.ndarray, by: IndexLabel | None
131-
) -> Series | np.ndarray:
131+
def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
132132
"""Internal function to reformat y given `by` is applied or not for hist plot.
133133
134134
If by is None, input y is 1-d with NaN removed; and if by is not None, groupby

0 commit comments

Comments
 (0)